diff --git a/Cargo.toml b/Cargo.toml index ddc3108..0ee84cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ jrd = "0.1" tracing = "0.1" tracing-subscriber = "0.3" clap = { version = "4.5", features = ["derive"] } +futures = "0.3" tokio = { version = "1.35", features = ["full"] } # TODO slim this down sea-orm = { version = "0.12", features = ["macros", "sqlx-sqlite", "runtime-tokio-rustls"] } reqwest = { version = "0.12", features = ["json"] } diff --git a/src/main.rs b/src/main.rs index 6bc592d..4f059c2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -72,7 +72,23 @@ enum CliCommand { actor: String, #[arg(long, default_value_t = false)] + /// instead of sending a follow request, send an accept accept: bool + }, + + /// run db maintenance tasks + Fix { + #[arg(long, default_value_t = false)] + /// fix likes counts for posts + likes: bool, + + #[arg(long, default_value_t = false)] + /// fix shares counts for posts + shares: bool, + + #[arg(long, default_value_t = false)] + /// fix replies counts for posts + replies: bool, } } @@ -147,6 +163,11 @@ async fn main() { .expect("could not dispatch relay activity"); }, + CliCommand::Fix { likes, shares, replies } => + fix(db, likes, shares, replies) + .await + .expect("failed running fix task"), + CliCommand::Serve => { let ctx = server::Context::new(db, args.domain) .await.expect("failed creating server context"); @@ -211,3 +232,73 @@ async fn fetch(db: sea_orm::DatabaseConnection, domain: String, uri: String, sav Ok(()) } + +async fn fix(db: sea_orm::DatabaseConnection, likes: bool, shares: bool, replies: bool) -> crate::Result<()> { + use futures::TryStreamExt; + + if likes { + tracing::info!("fixing likes..."); + let mut store = std::collections::HashMap::new(); + let mut stream = model::like::Entity::find().stream(&db).await?; + while let Some(like) = stream.try_next().await? { + store.insert(like.likes.clone(), store.get(&like.likes).unwrap_or(&0) + 1); + } + + for (k, v) in store { + let m = model::object::ActiveModel { + id: sea_orm::Set(k), + likes: sea_orm::Set(v), + ..Default::default() + }; + model::object::Entity::update(m) + .exec(&db) + .await?; + } + } + + if shares { + tracing::info!("fixing shares..."); + let mut store = std::collections::HashMap::new(); + let mut stream = model::share::Entity::find().stream(&db).await?; + while let Some(share) = stream.try_next().await? { + store.insert(share.shares.clone(), store.get(&share.shares).unwrap_or(&0) + 1); + } + + for (k, v) in store { + let m = model::object::ActiveModel { + id: sea_orm::Set(k), + shares: sea_orm::Set(v), + ..Default::default() + }; + model::object::Entity::update(m) + .exec(&db) + .await?; + } + } + + if replies { + tracing::info!("fixing replies..."); + let mut store = std::collections::HashMap::new(); + let mut stream = model::object::Entity::find().stream(&db).await?; + while let Some(object) = stream.try_next().await? { + if let Some(reply) = object.in_reply_to { + let before = store.get(&reply).unwrap_or(&0); + store.insert(reply, before + 1); + } + } + + for (k, v) in store { + let m = model::object::ActiveModel { + id: sea_orm::Set(k), + comments: sea_orm::Set(v), + ..Default::default() + }; + model::object::Entity::update(m) + .exec(&db) + .await?; + } + } + + tracing::info!("done running fix tasks"); + Ok(()) +}