diff --git a/src/main.rs b/src/main.rs index 656023e..d55a03b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -55,12 +55,12 @@ async fn main() { let args = CliArgs::parse(); - let session = Session::connect( + let session = Arc::new(Session::new( &args.server, Some(args.port), Some(args.username), args.password, - ).await.expect("could not connect to mumble server"); + )); // build our application with a route let mut app = Router::new(); @@ -69,9 +69,9 @@ async fn main() { app = app.route("/ping", get(ping_server)); } - if !args.no_peek { - app = app.route("/peek", get(peek_server)); - } + // if !args.no_peek { + // app = app.route("/peek", get(peek_server)); + // } let app = app .route("/info", get(server_info)) @@ -102,11 +102,13 @@ async fn server_ws(ws: WebSocketUpgrade, State(session): State>) -> async fn handle_ws(mut socket: WebSocket, mut sub: broadcast::Receiver) { while let Ok(event) = sub.recv().await { - match event { + if let Err(e) = match event { session::SessionEvent::AddUser(user) => - socket.send(Message::Text(serde_json::to_string(&user).expect("could not serialize user"))).await.unwrap(), + socket.send(Message::Text(serde_json::to_string(&user).expect("could not serialize user"))).await, session::SessionEvent::RemoveUser(id) => - socket.send(Message::Text(format!("{{\"remove\":{id}}}"))).await.unwrap(), + socket.send(Message::Text(format!("{{\"remove\":{id}}}"))).await, + } { + tracing::debug!("websocket disconnected: {e}"); } } } @@ -126,14 +128,14 @@ async fn ping_server(Query(options): Query) -> Result) -> Result>, String> { - match Session::connect( - &options.host, options.port, options.username, options.password - ).await { - Err(e) => Err(format!("could not connect to server: {e}")), - Ok(s) => { - s.ready().await; - Ok(Json(s.users().await)) - }, - } -} +// async fn peek_server(Query(options): Query) -> Result>, String> { +// match Session::new( +// &options.host, options.port, options.username, options.password +// ).await { +// Err(e) => Err(format!("could not connect to server: {e}")), +// Ok(s) => { +// s.ready().await; +// Ok(Json(s.users().await)) +// }, +// } +// } diff --git a/src/session.rs b/src/session.rs index 2ee12a4..54a1d16 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,17 +1,24 @@ -use std::{borrow::Borrow, collections::HashMap, net::SocketAddr, sync::Arc}; +use std::{borrow::Borrow, collections::HashMap, net::SocketAddr, sync::{atomic::AtomicBool, Arc}}; -use tokio::{net::UdpSocket, sync::{broadcast, mpsc::{self, error::TrySendError}, watch, RwLock}}; +use tokio::{net::UdpSocket, sync::{broadcast, RwLock}}; use crate::{model::User, tcp::{control::ControlChannel, proto}, udp::proto::{PingPacket, PongPacket}}; #[derive(Debug)] pub struct Session { - users: RwLock>, - username: String, - host: String, - sync: watch::Receiver, - drop: mpsc::Sender<()>, - events: broadcast::Sender, + options: Arc, + users: Arc>>, + // sync: watch::Receiver, + run: Arc, + events: Arc>, +} + +#[derive(Debug, Clone, Default)] +pub struct SessionOptions { + pub username: String, + pub password: Option, + pub host: String, + pub port: u16, } #[derive(Debug, Clone)] @@ -22,11 +29,7 @@ pub enum SessionEvent { impl Drop for Session { fn drop(&mut self) { - match self.drop.try_send(()) { - Ok(()) => {}, - Err(TrySendError::Full(())) => tracing::warn!("session stop channel full"), - Err(TrySendError::Closed(())) => tracing::warn!("session stop channel already closed"), - } + self.run.store(false, std::sync::atomic::Ordering::Relaxed); } } @@ -60,115 +63,141 @@ impl Session { }) } - pub async fn ready(&self) { - let mut sync = self.sync.clone(); - loop { - if *sync.borrow() { break } - sync.changed().await.unwrap(); - } - } + // pub async fn ready(&self) { + // let mut sync = self.sync.clone(); + // loop { + // if *sync.borrow() { break } + // sync.changed().await.unwrap(); + // } + // } pub async fn users(&self) -> Vec { self.users.read().await .borrow() .values() - .filter(|u| u.name != self.username) + .filter(|u| u.name != self.options.username) .cloned() .collect() } pub fn host(&self) -> String { - self.host.to_string() + self.options.host.to_string() } pub fn events(&self) -> broadcast::Receiver { self.events.subscribe() } - pub async fn connect(host: &str, port: Option, username: Option, password: Option) -> std::io::Result> { - let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string()); - let channel = Arc::new(ControlChannel::new(host, port).await?); - let version = proto::Version { - version_v1: None, - version_v2: Some(281496485429248), - release: Some("1.5.517".into()), - os: None, - os_version: None, - }; - let authenticate = proto::Authenticate { - username: Some(username.clone()), - password, - tokens: Vec::new(), - celt_versions: Vec::new(), - opus: Some(true), - client_type: Some(1), - }; + // async fn connect(&self) -> std::io::Result<()> { + // Self::connect_session(self.options.clone(), self.run.clone(), self.users.clone(), self.events.clone()).await + // } + + async fn connect_session( + options: Arc, + run: Arc, + users: Arc>>, + events: Arc>, + ) -> std::io::Result<()> { + let channel = Arc::new(ControlChannel::new(&options.host, Some(options.port)).await?); for pkt in [ - proto::Packet::Version(version), - proto::Packet::Authenticate(authenticate), + proto::Packet::Version(proto::Version { + version_v1: None, + version_v2: Some(281496485429248), + release: Some("1.5.517".into()), + os: None, + os_version: None, + }), + proto::Packet::Authenticate(proto::Authenticate { + username: Some(options.username.clone()), + password: options.password.clone(), + tokens: Vec::new(), + celt_versions: Vec::new(), + opus: Some(true), + client_type: Some(1), + }), ] { channel.send(pkt).await?; } - let (drop, mut stop) = mpsc::channel(1); - let (ready, sync) = watch::channel(false); - let (events, _) = broadcast::channel(64); + let mut tasks = tokio::task::JoinSet::new(); - let s = Arc::new(Session { - drop, sync, events, - username: username.clone(), - users : RwLock::new(HashMap::new()), - host: host.to_string(), - }); - - let session = s.clone(); - let chan = channel.clone(); - tokio::spawn(async move { - loop { - match stop.try_recv() { - Ok(()) => break, - Err(mpsc::error::TryRecvError::Empty) => {}, - Err(mpsc::error::TryRecvError::Disconnected) => break tracing::warn!("all session dropped without stopping this task, stopping..."), - } - match chan.recv().await { - Err(e) => break tracing::warn!("disconnected from server: {}", e), - // Ok(tcp::proto::Packet::TextMessage(msg)) => tracing::info!("{}", msg.message), - // Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name), - Ok(proto::Packet::UserRemove(user)) => { - tracing::info!("remove user: {:?}", user); - session.users.write().await.remove(&user.session); - let _ = session.events.send(SessionEvent::RemoveUser(user.session)); + let _channel = channel.clone(); + let _run = run.clone(); + tasks.spawn(async move { + while _run.load(std::sync::atomic::Ordering::Relaxed) { + match _channel.recv().await { + Err(e) => { + tracing::warn!("disconnected from server: {}", e); + break; }, - Ok(proto::Packet::ServerSync(_sync)) => { - tracing::info!("synched: {:?}", _sync); - ready.send(true).unwrap(); + Ok(proto::Packet::UserRemove(user)) => { + tracing::debug!("removing user: {:?}", user); + users.write().await.remove(&user.session); + let _ = events.send(SessionEvent::RemoveUser(user.session)); }, Ok(proto::Packet::UserState(user)) => { - tracing::info!("user state: {:?}", user); - let mut users = session.users.write().await; + tracing::debug!("updating user state: {:?}", user); + let mut users = users.write().await; let id = user.session(); match users.get_mut(&id) { Some(u) => u.update(user), None => { users.insert(user.session(), User::from(user)); }, } - let _ = session.events.send( + let _ = events.send( SessionEvent::AddUser(users.get(&id).cloned().expect("just inserted")) ); // if it fails nobody is listening }, - Ok(pkt) => tracing::info!("ignoring packet {:?}", pkt), + Ok(pkt) => tracing::debug!("ignoring packet {:?}", pkt), + } + } + users.write().await.clear(); + }); + + tasks.spawn(async move { + while run.load(std::sync::atomic::Ordering::Relaxed) { + tokio::time::sleep(std::time::Duration::from_secs(20)).await; + if let Err(e) = channel.send(proto::Packet::Ping(proto::Ping::default())).await { + tracing::warn!("could not send ping: {e}"); + break; } } }); - let chan = channel.clone(); + while let Some(res) = tasks.join_next().await { res? } + + Ok(()) + } + + pub fn new(host: &str, port: Option, username: Option, password: Option) -> Self { + let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string()); + let (events, _) = broadcast::channel(64); + + let s = Session { + events: Arc::new(events), + users : Arc::new(RwLock::new(HashMap::new())), + run: Arc::new(AtomicBool::new(true)), + options: Arc::new(SessionOptions { + username, password, + host: host.to_string(), + port: port.unwrap_or(64738), + }), + }; + + let options = s.options.clone(); + let run = s.run.clone(); + let users = s.users.clone(); + let events = s.events.clone(); tokio::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(20)).await; - chan.send(proto::Packet::Ping(proto::Ping::default())).await.unwrap(); + while run.load(std::sync::atomic::Ordering::Relaxed) { + if let Err(e) = Self::connect_session(options.clone(), run.clone(), users.clone(), events.clone()).await { + tracing::error!("could not connect to mumble: {e}"); + } + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + tracing::info!("attempting to reconnect..."); } }); - Ok(s) + s } }