diff --git a/src/activitypub/activity.rs b/src/activitypub/activity.rs index b71a4dc1..207e1faa 100644 --- a/src/activitypub/activity.rs +++ b/src/activitypub/activity.rs @@ -1,14 +1,12 @@ -use std::{ops::Deref, sync::Arc}; - use axum::{extract::{Path, State}, http::StatusCode, Json}; -use sea_orm::{DatabaseConnection, EntityTrait}; +use sea_orm::EntityTrait; -use crate::{activitystream::Base, model::activity}; +use crate::{activitystream::Base, model::activity, server::Context}; -pub async fn view(State(db) : State>, Path(id): Path) -> Result, StatusCode> { +pub async fn view(State(ctx) : State, Path(id): Path) -> Result, StatusCode> { let uri = format!("http://localhost:3000/activities/{id}"); - match activity::Entity::find_by_id(uri).one(db.deref()).await { + match activity::Entity::find_by_id(uri).one(ctx.db()).await { Ok(Some(activity)) => Ok(Json(activity.underlying_json_object())), Ok(None) => Err(StatusCode::NOT_FOUND), Err(e) => { diff --git a/src/activitypub/mod.rs b/src/activitypub/mod.rs index e51fba01..f8d347c0 100644 --- a/src/activitypub/mod.rs +++ b/src/activitypub/mod.rs @@ -2,22 +2,13 @@ pub mod user; pub mod object; pub mod activity; -use std::{ops::Deref, sync::Arc}; use axum::{extract::State, http::StatusCode, Json}; -use sea_orm::{DatabaseConnection, EntityTrait, IntoActiveModel}; +use sea_orm::{EntityTrait, IntoActiveModel}; -use crate::{activitystream::{object::{ObjectType, activity::{Activity, ActivityType}}, Base, BaseType, Node}, model}; +use crate::{activitystream::{object::{activity::{Activity, ActivityType}, ObjectType}, Base, BaseType, Node}, model, server::Context}; -pub fn uri_id(entity: &str, id: String) -> String { - if id.starts_with("http") { id } else { format!("http://localhost:3000/{entity}/{id}") } -} - -pub fn id_uri(id: &str) -> &str { - id.split('/').last().unwrap_or("") -} - #[derive(Debug, serde::Deserialize)] // TODO i don't really like how pleroma/mastodon do it actually, maybe change this? pub struct Page { @@ -25,7 +16,7 @@ pub struct Page { pub max_id: Option, } -pub async fn inbox(State(db) : State>, Json(object): Json) -> Result, StatusCode> { +pub async fn inbox(State(ctx) : State, Json(object): Json) -> Result, StatusCode> { match object.base_type() { None => { Err(StatusCode::BAD_REQUEST) }, Some(BaseType::Link(_x)) => Err(StatusCode::UNPROCESSABLE_ENTITY), // we could but not yet @@ -44,10 +35,10 @@ pub async fn inbox(State(db) : State>, Json(object): Jso return Err(StatusCode::UNPROCESSABLE_ENTITY); }; model::object::Entity::insert(obj_entity.into_active_model()) - .exec(db.deref()) + .exec(ctx.db()) .await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; model::activity::Entity::insert(activity_entity.into_active_model()) - .exec(db.deref()) + .exec(ctx.db()) .await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(serde_json::Value::Null)) // TODO hmmmmmmmmmmm not the best value to return.... }, @@ -56,6 +47,6 @@ pub async fn inbox(State(db) : State>, Json(object): Jso } } -pub async fn outbox(State(_db): State>) -> Result, StatusCode> { +pub async fn outbox(State(_db): State) -> Result, StatusCode> { todo!() } diff --git a/src/activitypub/object.rs b/src/activitypub/object.rs index f735c97c..1fe23445 100644 --- a/src/activitypub/object.rs +++ b/src/activitypub/object.rs @@ -1,14 +1,11 @@ -use std::{ops::Deref, sync::Arc}; - use axum::{extract::{Path, State}, http::StatusCode, Json}; -use sea_orm::{DatabaseConnection, EntityTrait}; +use sea_orm::EntityTrait; -use crate::{activitystream::Base, model::object}; +use crate::{activitystream::Base, model::object, server::Context}; -pub async fn view(State(db) : State>, Path(id): Path) -> Result, StatusCode> { - let uri = format!("http://localhost:3000/objects/{id}"); - match object::Entity::find_by_id(uri).one(db.deref()).await { +pub async fn view(State(ctx) : State, Path(id): Path) -> Result, StatusCode> { + match object::Entity::find_by_id(ctx.uri("objects", id)).one(ctx.db()).await { Ok(Some(object)) => Ok(Json(object.underlying_json_object())), Ok(None) => Err(StatusCode::NOT_FOUND), Err(e) => { diff --git a/src/activitypub/user.rs b/src/activitypub/user.rs index 17293640..966c8063 100644 --- a/src/activitypub/user.rs +++ b/src/activitypub/user.rs @@ -3,14 +3,14 @@ use std::sync::Arc; use axum::{extract::{Path, Query, State}, http::StatusCode, Json}; use sea_orm::{ColumnTrait, Condition, DatabaseConnection, EntityTrait, IntoActiveModel, Order, QueryFilter, QueryOrder, QuerySelect}; -use crate::{activitystream::{self, object::{activity::{Activity, ActivityType}, collection::{page::CollectionPageMut, CollectionMut, CollectionType}, ObjectType}, Base, BaseMut, BaseType, Node}, model::{self, activity, object, user}}; +use crate::{activitystream::{self, object::{activity::{Activity, ActivityType}, collection::{page::CollectionPageMut, CollectionMut, CollectionType}, ObjectType}, Base, BaseMut, BaseType, Node}, model::{self, activity, object, user}, server::Context}; pub async fn list(State(_db) : State>) -> Result, StatusCode> { todo!() } -pub async fn view(State(db) : State>, Path(id): Path) -> Result, StatusCode> { - match user::Entity::find_by_id(super::uri_id("users", id)).one(&*db).await { +pub async fn view(State(ctx) : State, Path(id): Path) -> Result, StatusCode> { + match user::Entity::find_by_id(ctx.uri("users", id)).one(ctx.db()).await { Ok(Some(user)) => Ok(Json(user.underlying_json_object())), Ok(None) => Err(StatusCode::NOT_FOUND), Err(e) => { @@ -21,7 +21,7 @@ pub async fn view(State(db) : State>, Path(id): Path>, + State(ctx): State, Path(id): Path, Query(page): Query, ) -> Result, StatusCode> { @@ -29,8 +29,8 @@ pub async fn outbox( // find requested recent post, to filter based on its date (use now() as fallback) let before = if let Some(before) = page.max_id { - match model::activity::Entity::find_by_id(super::uri_id("activities", before)) - .one(&*db).await + match model::activity::Entity::find_by_id(ctx.uri("activities", before)) + .one(ctx.db()).await { Ok(None) => return Err(StatusCode::NOT_FOUND), Ok(Some(x)) => x.published, @@ -45,11 +45,11 @@ pub async fn outbox( .filter(Condition::all().add(activity::Column::Published.lt(before))) .order_by(activity::Column::Published, Order::Desc) .limit(20) // TODO allow customizing, with boundaries - .all(&*db).await + .all(ctx.db()).await { Err(_e) => Err(StatusCode::INTERNAL_SERVER_ERROR), Ok(items) => { - let next = super::id_uri(&items.last().unwrap().id).to_string(); + let next = ctx.id(items.last().map(|x| x.id.as_str()).unwrap_or("").to_string()); let items = items .into_iter() .map(|i| i.underlying_json_object()) @@ -76,7 +76,7 @@ pub async fn outbox( } pub async fn inbox( - State(db): State>, + State(ctx): State, Path(_id): Path, Json(object): Json ) -> Result, StatusCode> { @@ -98,10 +98,10 @@ pub async fn inbox( return Err(StatusCode::UNPROCESSABLE_ENTITY); }; object::Entity::insert(obj_entity.into_active_model()) - .exec(&*db) + .exec(ctx.db()) .await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; activity::Entity::insert(activity_entity.into_active_model()) - .exec(&*db) + .exec(ctx.db()) .await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(serde_json::Value::Null)) // TODO hmmmmmmmmmmm not the best value to return.... }, diff --git a/src/main.rs b/src/main.rs index c2d14069..ac033d19 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,10 @@ struct CliArgs { /// database connection uri database: String, + #[arg(short, long, default_value = "http://localhost:3000")] + /// instance base domain, for AP ids + domain: String, + #[arg(long, default_value_t=false)] /// run with debug level tracing debug: bool, @@ -66,13 +70,13 @@ async fn main() { .await.expect("error connecting to db"); match args.command { - CliCommand::Serve => server::serve(db) + CliCommand::Serve => server::serve(db, args.domain) .await, CliCommand::Migrate => migrations::Migrator::up(&db, None) .await.expect("error applying migrations"), - CliCommand::Faker => model::faker(&db) + CliCommand::Faker => model::faker(&db, args.domain) .await.expect("error creating fake entities"), CliCommand::Fetch { uri, save } => fetch(&db, &uri, save) diff --git a/src/model/mod.rs b/src/model/mod.rs index 25f3d16a..c9593356 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -6,11 +6,11 @@ pub mod activity; #[error("missing required field: '{0}'")] pub struct FieldError(pub &'static str); -pub async fn faker(db: &sea_orm::DatabaseConnection) -> Result<(), sea_orm::DbErr> { +pub async fn faker(db: &sea_orm::DatabaseConnection, domain: String) -> Result<(), sea_orm::DbErr> { use sea_orm::EntityTrait; user::Entity::insert(user::ActiveModel { - id: sea_orm::Set("http://localhost:3000/users/root".into()), + id: sea_orm::Set(format!("{domain}/users/root")), name: sea_orm::Set("root".into()), actor_type: sea_orm::Set(super::activitystream::object::actor::ActorType::Person), }).exec(db).await?; @@ -19,20 +19,20 @@ pub async fn faker(db: &sea_orm::DatabaseConnection) -> Result<(), sea_orm::DbEr let oid = uuid::Uuid::new_v4(); let aid = uuid::Uuid::new_v4(); object::Entity::insert(object::ActiveModel { - id: sea_orm::Set(format!("http://localhost:3000/objects/{oid}")), + id: sea_orm::Set(format!("{domain}/objects/{oid}")), name: sea_orm::Set(None), object_type: sea_orm::Set(crate::activitystream::object::ObjectType::Note), - attributed_to: sea_orm::Set(Some("http://localhost:3000/users/root".into())), + attributed_to: sea_orm::Set(Some(format!("{domain}/users/root"))), summary: sea_orm::Set(None), - content: sea_orm::Set(Some(format!("Hello world! {i}"))), + content: sea_orm::Set(Some(format!("[{i}] Tic(k). Quasiparticle of intensive multiplicity. Tics (or ticks) are intrinsically several components of autonomously numbering anorganic populations, propagating by contagion between segmentary divisions in the order of nature. Ticks - as nonqualitative differentially-decomposable counting marks - each designate a multitude comprehended as a singular variation in tic(k)-density."))), published: sea_orm::Set(chrono::Utc::now() - std::time::Duration::from_secs(60*i)), }).exec(db).await?; activity::Entity::insert(activity::ActiveModel { - id: sea_orm::Set(format!("http://localhost:3000/activities/{aid}")), + id: sea_orm::Set(format!("{domain}/activities/{aid}")), activity_type: sea_orm::Set(crate::activitystream::object::activity::ActivityType::Create), - actor: sea_orm::Set("http://localhost:3000/users/root".into()), - object: sea_orm::Set(Some(format!("http://localhost:3000/objects/{oid}"))), + actor: sea_orm::Set(format!("{domain}/users/root")), + object: sea_orm::Set(Some(format!("{domain}/objects/{oid}"))), target: sea_orm::Set(None), published: sea_orm::Set(chrono::Utc::now() - std::time::Duration::from_secs(60*i)), }).exec(db).await?; diff --git a/src/server.rs b/src/server.rs index a5328a59..0af3d13e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,7 +4,47 @@ use axum::{routing::{get, post}, Router}; use sea_orm::DatabaseConnection; use crate::activitypub as ap; -pub async fn serve(db: DatabaseConnection) { +#[derive(Clone)] +pub struct Context(Arc); +struct ContextInner { + db: DatabaseConnection, + domain: String, +} +impl Context { + pub fn new(db: DatabaseConnection, mut domain: String) -> Self { + if !domain.starts_with("http") { + domain = format!("https://{domain}"); + } + if domain.ends_with('/') { + domain.replace_range(domain.len()-1.., ""); + } + Context(Arc::new(ContextInner { db, domain })) + } + + pub fn db(&self) -> &DatabaseConnection { + &self.0.db + } + + pub fn uri(&self, entity: &str, id: String) -> String { + if id.starts_with("http") { id } else { + format!("{}/{}/{}", self.0.domain, entity, id) + } + } + + pub fn id(&self, id: String) -> String { + if id.starts_with(&self.0.domain) { + let mut out = id.replace(&self.0.domain, ""); + if out.ends_with('/') { + out.replace_range(out.len()-1.., ""); + } + out + } else { + id + } + } +} + +pub async fn serve(db: DatabaseConnection, domain: String) { // build our application with a single route let app = Router::new() // core server inbox/outbox, maybe for feeds? TODO do we need these? @@ -17,7 +57,7 @@ pub async fn serve(db: DatabaseConnection) { // specific object routes .route("/activities/:id", get(ap::activity::view)) .route("/objects/:id", get(ap::object::view)) - .with_state(Arc::new(db)); + .with_state(Context::new(db, domain)); // run our app with hyper, listening globally on port 3000 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();