upub/src/server/auth.rs

161 lines
4.5 KiB
Rust
Raw Normal View History

use axum::{extract::{FromRef, FromRequestParts}, http::{header, request::Parts}};
2024-04-13 01:49:23 +02:00
use base64::Engine;
use openssl::{hash::MessageDigest, pkey::PKey, sign::Verifier};
2024-03-25 01:58:30 +01:00
use sea_orm::{ColumnTrait, Condition, EntityTrait, QueryFilter};
2024-04-13 01:49:23 +02:00
use crate::{errors::UpubError, model, server::Context};
2024-03-25 01:58:30 +01:00
/// Who is performing a request, as resolved by the [`AuthIdentity`] extractor.
#[derive(Debug, Clone)]
pub enum Identity {
	/// No credentials were presented (or a presented HTTP signature did not verify).
	Anonymous,
	/// A local user authenticated through a `Bearer` session token; holds the session's actor id.
	Local(String),
	/// A remote server whose HTTP signature verified; holds the value returned by `Context::server`
	/// for the signing actor's id.
	Remote(String),
}
impl Identity {
	/// Builds the `addressing` filter selecting rows this identity is allowed to see.
	///
	/// Rows addressed to `apb::target::PUBLIC` are always included. A local user additionally
	/// matches rows whose `Actor` column is their own id, and a remote server matches rows whose
	/// `Server` column equals it. All alternatives are OR-ed together (`Condition::any`).
	pub fn filter_condition(&self) -> Condition {
		let extra = match self {
			Self::Anonymous => None,
			Self::Local(uid) => Some(model::addressing::Column::Actor.eq(uid)),
			// TODO should we allow all users on same server to see? or just specific user??
			Self::Remote(server) => Some(model::addressing::Column::Server.eq(server)),
		};
		let visible = Condition::any()
			.add(model::addressing::Column::Actor.eq(apb::target::PUBLIC));
		match extra {
			Some(expr) => visible.add(expr),
			None => visible,
		}
	}
}
2024-03-25 01:58:30 +01:00
pub struct AuthIdentity(pub Identity);
#[axum::async_trait]
impl<S> FromRequestParts<S> for AuthIdentity
where
Context: FromRef<S>,
S: Send + Sync,
{
2024-04-13 01:49:23 +02:00
type Rejection = UpubError;
2024-03-25 01:58:30 +01:00
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let ctx = Context::from_ref(state);
let mut identity = Identity::Anonymous;
let auth_header = parts
.headers
.get(header::AUTHORIZATION)
.map(|v| v.to_str().unwrap_or(""))
.unwrap_or("");
if auth_header.starts_with("Bearer ") {
match model::session::Entity::find_by_id(auth_header.replace("Bearer ", ""))
.filter(Condition::all().add(model::session::Column::Expires.gt(chrono::Utc::now())))
.one(ctx.db())
.await
{
Ok(Some(x)) => identity = Identity::Local(x.actor),
2024-04-13 01:49:23 +02:00
Ok(None) => return Err(UpubError::unauthorized()),
2024-03-25 01:58:30 +01:00
Err(e) => {
tracing::error!("failed querying user session: {e}");
2024-04-13 01:49:23 +02:00
return Err(UpubError::internal_server_error())
2024-03-25 01:58:30 +01:00
},
}
}
2024-04-13 01:49:23 +02:00
if let Some(sig) = parts
.headers
.get("Signature")
.map(|v| v.to_str().unwrap_or(""))
{
let http_signature = HttpSignature::parse(sig);
let user_id = http_signature.key_id.replace("#main-key", "");
match ctx.fetch().user(&user_id).await {
Ok(user) => {
let to_sign = http_signature.build_string(parts);
// TODO assert payload's digest is equal to signature's
match verify_control_text(&to_sign, &user.public_key, &http_signature.signature) {
Ok(true) => identity = Identity::Remote(Context::server(&user_id)),
Ok(false) => tracing::warn!("invalid signature"),
Err(e) => tracing::error!("error verifying signature: {e}"),
}
},
Err(e) => tracing::warn!("could not fetch user (won't verify): {e}"),
}
2024-04-13 01:49:23 +02:00
}
2024-03-25 01:58:30 +01:00
Ok(AuthIdentity(identity))
}
}
fn verify_control_text(txt: &str, key: &str, control: &str) -> crate::Result<bool> {
let pubkey = PKey::public_key_from_pem(key.as_bytes())?;
let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
verifier.update(txt.as_bytes())?;
Ok(verifier.verify(&base64::prelude::BASE64_URL_SAFE.decode(control).unwrap_or_default())?)
}
/// Parameters carried by an HTTP `Signature` header; every field is empty until parsed.
#[derive(Debug, Clone, Default)]
pub struct HttpSignature {
	/// Raw `keyId` parameter: identifies the key (and, through it, the actor) that signed.
	key_id: String,
	/// Raw `algorithm` parameter, e.g. `rsa-sha256`.
	algorithm: String,
	/// Names from the `headers` parameter, in the order they enter the signing string.
	headers: Vec<String>,
	/// Raw `signature` parameter, still base64-encoded.
	signature: String,
}
impl HttpSignature {
	/// Parses the value of a `Signature` header: comma-separated `key="value"` pairs.
	///
	/// Recognized keys are `keyId`, `algorithm`, `signature` and `headers` (a whitespace-separated
	/// list). Unknown keys are logged and ignored, and fields that are absent stay empty; parsing
	/// never fails. Whitespace around keys and values, and surrounding double quotes on values,
	/// are stripped.
	pub fn parse(header: &str) -> Self {
		let mut sig = HttpSignature::default();
		header.split(',')
			// split at the first '=' only, so base64 padding in values is preserved
			.filter_map(|x| x.split_once('='))
			.map(|(k, v)| (k.trim(), v.trim().trim_matches('"')))
			.for_each(|(k, v)| match k {
				"keyId" => sig.key_id = v.to_string(),
				"algorithm" => sig.algorithm = v.to_string(),
				"signature" => sig.signature = v.to_string(),
				"headers" => sig.headers = v.split_whitespace().map(str::to_string).collect(),
				_ => tracing::warn!("unexpected field in http signature: '{k}=\"{v}\"'"),
			});
		sig
	}

	/// Rebuilds the signing string for this signature from the request `parts`.
	///
	/// There is one `name: value` line per entry of `headers`, in order, joined by `\n`.
	/// - `(request-target)` expands to the lowercased method and the path with query, e.g.
	///   `(request-target): post /inbox`.
	/// - Any other name is looked up in the request headers. Repeated values are joined with
	///   `, `, and missing or non-ASCII values become empty.
	pub fn build_string(&self, parts: &Parts) -> String {
		self.headers
			.iter()
			.map(|header| match header.as_str() {
				"(request-target)" => format!(
					"(request-target): {} {}",
					parts.method.as_str().to_lowercase(),
					parts.uri.path_and_query().map(|x| x.as_str()).unwrap_or("/"),
				),
				// TODO other pseudo-headers,
				_ => format!(
					"{}: {}",
					header.to_lowercase(),
					parts.headers
						.get_all(header.as_str())
						.iter()
						.map(|x| x.to_str().unwrap_or(""))
						.collect::<Vec<_>>()
						.join(", "),
				),
			})
			.collect::<Vec<_>>()
			.join("\n")
	}

	/// Maps the `algorithm` parameter to an OpenSSL digest; unknown values fall back to SHA-256.
	pub fn digest(&self) -> MessageDigest {
		match self.algorithm.as_str() {
			"rsa-sha512" => MessageDigest::sha512(),
			"rsa-sha384" => MessageDigest::sha384(),
			"rsa-sha256" => MessageDigest::sha256(),
			"rsa-sha1" => MessageDigest::sha1(),
			_ => {
				tracing::error!("unknown digest algorithm, trying with rsa-sha256");
				MessageDigest::sha256()
			}
		}
	}
}