feat: auto reconnect, also cleanup and refactor

temporarily disabled peek route because im lazy but may re-enable in the
future
This commit is contained in:
əlemi 2024-02-21 23:56:14 +01:00
parent e4b33c11c9
commit e2403688fe
Signed by: alemi
GPG key ID: A4895B84D311642C
2 changed files with 131 additions and 100 deletions

View file

@ -55,12 +55,12 @@ async fn main() {
let args = CliArgs::parse(); let args = CliArgs::parse();
let session = Session::connect( let session = Arc::new(Session::new(
&args.server, &args.server,
Some(args.port), Some(args.port),
Some(args.username), Some(args.username),
args.password, args.password,
).await.expect("could not connect to mumble server"); ));
// build our application with a route // build our application with a route
let mut app = Router::new(); let mut app = Router::new();
@ -69,9 +69,9 @@ async fn main() {
app = app.route("/ping", get(ping_server)); app = app.route("/ping", get(ping_server));
} }
if !args.no_peek { // if !args.no_peek {
app = app.route("/peek", get(peek_server)); // app = app.route("/peek", get(peek_server));
} // }
let app = app let app = app
.route("/info", get(server_info)) .route("/info", get(server_info))
@ -102,11 +102,13 @@ async fn server_ws(ws: WebSocketUpgrade, State(session): State<Arc<Session>>) ->
async fn handle_ws(mut socket: WebSocket, mut sub: broadcast::Receiver<session::SessionEvent>) { async fn handle_ws(mut socket: WebSocket, mut sub: broadcast::Receiver<session::SessionEvent>) {
while let Ok(event) = sub.recv().await { while let Ok(event) = sub.recv().await {
match event { if let Err(e) = match event {
session::SessionEvent::AddUser(user) => session::SessionEvent::AddUser(user) =>
socket.send(Message::Text(serde_json::to_string(&user).expect("could not serialize user"))).await.unwrap(), socket.send(Message::Text(serde_json::to_string(&user).expect("could not serialize user"))).await,
session::SessionEvent::RemoveUser(id) => session::SessionEvent::RemoveUser(id) =>
socket.send(Message::Text(format!("{{\"remove\":{id}}}"))).await.unwrap(), socket.send(Message::Text(format!("{{\"remove\":{id}}}"))).await,
} {
tracing::debug!("websocket disconnected: {e}");
} }
} }
} }
@ -126,14 +128,14 @@ async fn ping_server(Query(options): Query<model::PingOptions>) -> Result<Json<m
} }
} }
async fn peek_server(Query(options): Query<model::PeekOptions>) -> Result<Json<Vec<model::User>>, String> { // async fn peek_server(Query(options): Query<model::PeekOptions>) -> Result<Json<Vec<model::User>>, String> {
match Session::connect( // match Session::new(
&options.host, options.port, options.username, options.password // &options.host, options.port, options.username, options.password
).await { // ).await {
Err(e) => Err(format!("could not connect to server: {e}")), // Err(e) => Err(format!("could not connect to server: {e}")),
Ok(s) => { // Ok(s) => {
s.ready().await; // s.ready().await;
Ok(Json(s.users().await)) // Ok(Json(s.users().await))
}, // },
} // }
} // }

View file

@ -1,17 +1,24 @@
use std::{borrow::Borrow, collections::HashMap, net::SocketAddr, sync::Arc}; use std::{borrow::Borrow, collections::HashMap, net::SocketAddr, sync::{atomic::AtomicBool, Arc}};
use tokio::{net::UdpSocket, sync::{broadcast, mpsc::{self, error::TrySendError}, watch, RwLock}}; use tokio::{net::UdpSocket, sync::{broadcast, RwLock}};
use crate::{model::User, tcp::{control::ControlChannel, proto}, udp::proto::{PingPacket, PongPacket}}; use crate::{model::User, tcp::{control::ControlChannel, proto}, udp::proto::{PingPacket, PongPacket}};
#[derive(Debug)] #[derive(Debug)]
pub struct Session { pub struct Session {
users: RwLock<HashMap<u32, User>>, options: Arc<SessionOptions>,
username: String, users: Arc<RwLock<HashMap<u32, User>>>,
host: String, // sync: watch::Receiver<bool>,
sync: watch::Receiver<bool>, run: Arc<AtomicBool>,
drop: mpsc::Sender<()>, events: Arc<broadcast::Sender<SessionEvent>>,
events: broadcast::Sender<SessionEvent>, }
#[derive(Debug, Clone, Default)]
pub struct SessionOptions {
pub username: String,
pub password: Option<String>,
pub host: String,
pub port: u16,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -22,11 +29,7 @@ pub enum SessionEvent {
impl Drop for Session { impl Drop for Session {
fn drop(&mut self) { fn drop(&mut self) {
match self.drop.try_send(()) { self.run.store(false, std::sync::atomic::Ordering::Relaxed);
Ok(()) => {},
Err(TrySendError::Full(())) => tracing::warn!("session stop channel full"),
Err(TrySendError::Closed(())) => tracing::warn!("session stop channel already closed"),
}
} }
} }
@ -60,115 +63,141 @@ impl Session {
}) })
} }
pub async fn ready(&self) { // pub async fn ready(&self) {
let mut sync = self.sync.clone(); // let mut sync = self.sync.clone();
loop { // loop {
if *sync.borrow() { break } // if *sync.borrow() { break }
sync.changed().await.unwrap(); // sync.changed().await.unwrap();
} // }
} // }
pub async fn users(&self) -> Vec<User> { pub async fn users(&self) -> Vec<User> {
self.users.read().await self.users.read().await
.borrow() .borrow()
.values() .values()
.filter(|u| u.name != self.username) .filter(|u| u.name != self.options.username)
.cloned() .cloned()
.collect() .collect()
} }
pub fn host(&self) -> String { pub fn host(&self) -> String {
self.host.to_string() self.options.host.to_string()
} }
pub fn events(&self) -> broadcast::Receiver<SessionEvent> { pub fn events(&self) -> broadcast::Receiver<SessionEvent> {
self.events.subscribe() self.events.subscribe()
} }
pub async fn connect(host: &str, port: Option<u16>, username: Option<String>, password: Option<String>) -> std::io::Result<Arc<Self>> { // async fn connect(&self) -> std::io::Result<()> {
let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string()); // Self::connect_session(self.options.clone(), self.run.clone(), self.users.clone(), self.events.clone()).await
let channel = Arc::new(ControlChannel::new(host, port).await?); // }
let version = proto::Version {
version_v1: None, async fn connect_session(
version_v2: Some(281496485429248), options: Arc<SessionOptions>,
release: Some("1.5.517".into()), run: Arc<AtomicBool>,
os: None, users: Arc<RwLock<HashMap<u32, User>>>,
os_version: None, events: Arc<broadcast::Sender<SessionEvent>>,
}; ) -> std::io::Result<()> {
let authenticate = proto::Authenticate { let channel = Arc::new(ControlChannel::new(&options.host, Some(options.port)).await?);
username: Some(username.clone()),
password,
tokens: Vec::new(),
celt_versions: Vec::new(),
opus: Some(true),
client_type: Some(1),
};
for pkt in [ for pkt in [
proto::Packet::Version(version), proto::Packet::Version(proto::Version {
proto::Packet::Authenticate(authenticate), version_v1: None,
version_v2: Some(281496485429248),
release: Some("1.5.517".into()),
os: None,
os_version: None,
}),
proto::Packet::Authenticate(proto::Authenticate {
username: Some(options.username.clone()),
password: options.password.clone(),
tokens: Vec::new(),
celt_versions: Vec::new(),
opus: Some(true),
client_type: Some(1),
}),
] { ] {
channel.send(pkt).await?; channel.send(pkt).await?;
} }
let (drop, mut stop) = mpsc::channel(1); let mut tasks = tokio::task::JoinSet::new();
let (ready, sync) = watch::channel(false);
let (events, _) = broadcast::channel(64);
let s = Arc::new(Session { let _channel = channel.clone();
drop, sync, events, let _run = run.clone();
username: username.clone(), tasks.spawn(async move {
users : RwLock::new(HashMap::new()), while _run.load(std::sync::atomic::Ordering::Relaxed) {
host: host.to_string(), match _channel.recv().await {
}); Err(e) => {
tracing::warn!("disconnected from server: {}", e);
let session = s.clone(); break;
let chan = channel.clone();
tokio::spawn(async move {
loop {
match stop.try_recv() {
Ok(()) => break,
Err(mpsc::error::TryRecvError::Empty) => {},
Err(mpsc::error::TryRecvError::Disconnected) => break tracing::warn!("all session dropped without stopping this task, stopping..."),
}
match chan.recv().await {
Err(e) => break tracing::warn!("disconnected from server: {}", e),
// Ok(tcp::proto::Packet::TextMessage(msg)) => tracing::info!("{}", msg.message),
// Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name),
Ok(proto::Packet::UserRemove(user)) => {
tracing::info!("remove user: {:?}", user);
session.users.write().await.remove(&user.session);
let _ = session.events.send(SessionEvent::RemoveUser(user.session));
}, },
Ok(proto::Packet::ServerSync(_sync)) => { Ok(proto::Packet::UserRemove(user)) => {
tracing::info!("synched: {:?}", _sync); tracing::debug!("removing user: {:?}", user);
ready.send(true).unwrap(); users.write().await.remove(&user.session);
let _ = events.send(SessionEvent::RemoveUser(user.session));
}, },
Ok(proto::Packet::UserState(user)) => { Ok(proto::Packet::UserState(user)) => {
tracing::info!("user state: {:?}", user); tracing::debug!("updating user state: {:?}", user);
let mut users = session.users.write().await; let mut users = users.write().await;
let id = user.session(); let id = user.session();
match users.get_mut(&id) { match users.get_mut(&id) {
Some(u) => u.update(user), Some(u) => u.update(user),
None => { users.insert(user.session(), User::from(user)); }, None => { users.insert(user.session(), User::from(user)); },
} }
let _ = session.events.send( let _ = events.send(
SessionEvent::AddUser(users.get(&id).cloned().expect("just inserted")) SessionEvent::AddUser(users.get(&id).cloned().expect("just inserted"))
); // if it fails nobody is listening ); // if it fails nobody is listening
}, },
Ok(pkt) => tracing::info!("ignoring packet {:?}", pkt), Ok(pkt) => tracing::debug!("ignoring packet {:?}", pkt),
}
}
users.write().await.clear();
});
tasks.spawn(async move {
while run.load(std::sync::atomic::Ordering::Relaxed) {
tokio::time::sleep(std::time::Duration::from_secs(20)).await;
if let Err(e) = channel.send(proto::Packet::Ping(proto::Ping::default())).await {
tracing::warn!("could not send ping: {e}");
break;
} }
} }
}); });
let chan = channel.clone(); while let Some(res) = tasks.join_next().await { res? }
Ok(())
}
pub fn new(host: &str, port: Option<u16>, username: Option<String>, password: Option<String>) -> Self {
let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string());
let (events, _) = broadcast::channel(64);
let s = Session {
events: Arc::new(events),
users : Arc::new(RwLock::new(HashMap::new())),
run: Arc::new(AtomicBool::new(true)),
options: Arc::new(SessionOptions {
username, password,
host: host.to_string(),
port: port.unwrap_or(64738),
}),
};
let options = s.options.clone();
let run = s.run.clone();
let users = s.users.clone();
let events = s.events.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { while run.load(std::sync::atomic::Ordering::Relaxed) {
tokio::time::sleep(std::time::Duration::from_secs(20)).await; if let Err(e) = Self::connect_session(options.clone(), run.clone(), users.clone(), events.clone()).await {
chan.send(proto::Packet::Ping(proto::Ping::default())).await.unwrap(); tracing::error!("could not connect to mumble: {e}");
}
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
tracing::info!("attempting to reconnect...");
} }
}); });
Ok(s) s
} }
} }