diff --git a/src/main.rs b/src/main.rs index 6b80040..e10d732 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,10 @@ -use std::net::ToSocketAddrs; +#![allow(clippy::use_self)] + +use std::{net::ToSocketAddrs, sync::Arc}; use clap::Parser; -use axum::{extract::Query, routing::get, Json, Router}; +use axum::{extract::{Query, State}, routing::get, Json, Router}; use session::Session; mod tcp; @@ -10,8 +12,39 @@ mod udp; mod session; mod model; +/// small http api providing mumble stats #[derive(Parser)] struct CliArgs { + /// mumble server to join and monitor + server: String, + + /// mumble server port + #[arg(short, long, default_value_t = 64738)] + port: u16, + + /// bot username on mumble + #[arg(short, long, default_value = ".mumble-stats-api")] + username: String, + + /// optional server password + #[arg(short, long)] + password: Option, + + /// host to bind on + #[arg(long = "bind-host", short = 'H', default_value = "127.0.0.1")] + bind_host: String, + + /// port to bind on + #[arg(long = "bind-port", short = 'P', default_value_t = 57039)] + bind_port: u16, + + /// disable arbitrary udp ping endpoing + #[arg(long, default_value_t = false)] + no_ping: bool, + + /// disable arbitrary server join and peek + #[arg(long, default_value_t = false)] + no_peek: bool, } #[tokio::main] @@ -19,24 +52,45 @@ async fn main() { // initialize tracing tracing_subscriber::fmt::init(); + let args = CliArgs::parse(); + + let session = Session::connect( + &args.server, + Some(args.port), + Some(args.username), + args.password, + ).await.expect("could not connect to mumble server"); + // build our application with a route - let app = Router::new() - .route("/ping", get(ping_server)) - .route("/users", get(list_server_users)); + let mut app = Router::new(); + + if !args.no_ping { + app = app.route("/ping", get(ping_server)); + } + + if !args.no_peek { + app = app.route("/peek", get(peek_server)); + } + + let app = app + .route("/info", get(server_info)) + .route("/users", get(server_users)) + .with_state(session); tracing::info!("serving mumble-stats-api"); - let listener = tokio::net::TcpListener::bind("127.0.0.1:57039").await + let listener = tokio::net::TcpListener::bind((args.bind_host, args.bind_port)).await .expect("could not bind on requested addr"); axum::serve(listener, app).await .expect("could not serve axum app"); } -async fn list_server_users(Query(options): Query) -> Result>, String> { - match Session::users(&options.host, options.port, options.username, options.password).await { - Ok(users) => Ok(Json(users)), - Err(e) => Err(format!("could not list users: {e}")), - } +async fn server_info(State(session): State>) -> Result, String> { + Ok(Json(session.host())) +} + +async fn server_users(State(session): State>) -> Result>, String> { + Ok(Json(session.users().await)) } async fn ping_server(Query(options): Query) -> Result, String> { @@ -53,3 +107,15 @@ async fn ping_server(Query(options): Query) -> Result) -> Result>, String> { + match Session::connect( + &options.host, options.port, options.username, options.password + ).await { + Err(e) => Err(format!("could not connect to server: {e}")), + Ok(s) => { + s.ready().await; + Ok(Json(s.users().await)) + }, + } +} diff --git a/src/model.rs b/src/model.rs index 65622fc..6e2aa7b 100644 --- a/src/model.rs +++ b/src/model.rs @@ -15,7 +15,7 @@ pub struct ServerInfo { } #[derive(serde::Deserialize)] -pub struct ListUsersOptions { +pub struct PeekOptions { pub host: String, pub port: Option, pub username: Option, @@ -23,7 +23,7 @@ pub struct ListUsersOptions { pub tokens: Option>, } -#[derive(Debug, serde::Serialize)] +#[derive(Debug, Clone, serde::Serialize)] pub struct User { /// Unique user session ID of the user whose state this is, may change on /// reconnect. @@ -44,7 +44,7 @@ pub struct User { pub properties: UserProperties, } -#[derive(Debug, serde::Serialize)] +#[derive(Debug, Clone, serde::Serialize)] pub struct UserProperties { /// True if the user is muted by admin. pub mute: bool, diff --git a/src/session.rs b/src/session.rs index 8d596b3..e05d869 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,11 +1,13 @@ -use std::{collections::HashMap, net::SocketAddr, sync::{atomic::AtomicBool, Arc}}; +use std::{borrow::Borrow, collections::HashMap, net::SocketAddr, sync::Arc}; use tokio::{net::UdpSocket, sync::{mpsc::{self, error::TrySendError}, watch, RwLock}}; use crate::{model::User, tcp::{control::ControlChannel, proto}, udp::proto::{PingPacket, PongPacket}}; +#[derive(Debug)] pub struct Session { - pub users: RwLock>, + users: RwLock>, + host: String, sync: watch::Receiver, drop: mpsc::Sender<()>, } @@ -50,49 +52,22 @@ impl Session { }) } - pub async fn users(host: &str, port: Option, username: Option, password: Option) -> std::io::Result> { - let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string()); - let mut channel = ControlChannel::new(host, port).await?; - let version = proto::Version { - version_v1: None, - version_v2: Some(281496485429248), - release: Some("1.5.517".into()), - os: None, - os_version: None, - }; - let authenticate = proto::Authenticate { - username: Some(username.clone()), - password, - tokens: Vec::new(), - celt_versions: Vec::new(), - opus: Some(true), - client_type: Some(1), - }; - - for pkt in [ - proto::Packet::Version(version), - proto::Packet::Authenticate(authenticate), - ] { - channel.send(pkt).await?; - } - - let mut users = Vec::new(); - + pub async fn ready(&self) { + let mut sync = self.sync.clone(); loop { - match channel.recv().await? { - // Ok(tcp::proto::Packet::TextMessage(msg)) => tracing::info!("{}", msg.message), - // Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name), - proto::Packet::ServerSync(_sync) => break Ok(users), - proto::Packet::UserState(user) => { - if user.name.as_ref().is_some_and(|n| n != &username) { - users.push(user.into()); - } - }, - pkt => tracing::debug!("ignoring packet {:#?}", pkt), - } + if *sync.borrow() { break } + sync.changed().await.unwrap(); } } + pub async fn users(&self) -> Vec { + self.users.read().await.borrow().values().cloned().collect() + } + + pub fn host(&self) -> String { + self.host.to_string() + } + pub async fn connect(host: &str, port: Option, username: Option, password: Option) -> std::io::Result> { let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string()); let mut channel = ControlChannel::new(host, port).await?; @@ -126,6 +101,7 @@ impl Session { users : RwLock::new(HashMap::new()), sync: ready, drop: tx, + host: host.to_string(), }); let session = s.clone(); @@ -142,6 +118,7 @@ impl Session { // Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name), Ok(proto::Packet::ServerSync(_sync)) => ready_tx.send(true).unwrap(), Ok(proto::Packet::UserState(user)) => { + tracing::info!("user state: {:#?}", user); if user.name.as_ref().is_some_and(|n| n != &username) { session.users.write().await.insert(user.user_id(), User::from(user)); } diff --git a/src/tcp/control.rs b/src/tcp/control.rs index ca3defd..16e8af2 100644 --- a/src/tcp/control.rs +++ b/src/tcp/control.rs @@ -1,4 +1,4 @@ -use std::net::ToSocketAddrs; +use std::net::{SocketAddr, ToSocketAddrs}; use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream}; use tokio_native_tls::TlsStream; @@ -10,9 +10,8 @@ pub struct ControlChannel { impl ControlChannel { pub async fn new(host: &str, port: Option) -> std::io::Result { let addr = (host, port.unwrap_or(64738)).to_socket_addrs()? - .filter(|a| a.is_ipv4()) - .next() - .ok_or(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable))?; + .find(SocketAddr::is_ipv4) + .ok_or_else(|| std::io::Error::from(std::io::ErrorKind::AddrNotAvailable))?; let socket = TcpStream::connect(addr).await?; // use native_tls builder and then .into() so we can pass options to the builder let connector : tokio_native_tls::TlsConnector = native_tls::TlsConnector::builder() @@ -37,6 +36,6 @@ impl ControlChannel { let size = self.stream.read_u32().await?; let mut buffer = vec![0u8; size as usize]; self.stream.read_exact(&mut buffer).await?; - Ok(super::proto::Packet::decode(id, &buffer)?) + super::proto::Packet::decode(id, &buffer) } }