feat: split control tx/rx, send pings from task

This commit is contained in:
əlemi 2024-02-21 18:32:08 +01:00
parent fc1da3d88d
commit c99decdd24
Signed by: alemi
GPG key ID: A4895B84D311642C
2 changed files with 29 additions and 15 deletions

View file

@ -70,7 +70,7 @@ impl Session {
pub async fn connect(host: &str, port: Option<u16>, username: Option<String>, password: Option<String>) -> std::io::Result<Arc<Self>> { pub async fn connect(host: &str, port: Option<u16>, username: Option<String>, password: Option<String>) -> std::io::Result<Arc<Self>> {
let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string()); let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string());
let mut channel = ControlChannel::new(host, port).await?; let channel = Arc::new(ControlChannel::new(host, port).await?);
let version = proto::Version { let version = proto::Version {
version_v1: None, version_v1: None,
version_v2: Some(281496485429248), version_v2: Some(281496485429248),
@ -105,6 +105,7 @@ impl Session {
}); });
let session = s.clone(); let session = s.clone();
let chan = channel.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
match rx.try_recv() { match rx.try_recv() {
@ -112,7 +113,7 @@ impl Session {
Err(mpsc::error::TryRecvError::Empty) => {}, Err(mpsc::error::TryRecvError::Empty) => {},
Err(mpsc::error::TryRecvError::Disconnected) => break tracing::warn!("all session dropped without stopping this task, stopping..."), Err(mpsc::error::TryRecvError::Disconnected) => break tracing::warn!("all session dropped without stopping this task, stopping..."),
} }
match channel.recv().await { match chan.recv().await {
Err(e) => break tracing::warn!("disconnected from server: {}", e), Err(e) => break tracing::warn!("disconnected from server: {}", e),
// Ok(tcp::proto::Packet::TextMessage(msg)) => tracing::info!("{}", msg.message), // Ok(tcp::proto::Packet::TextMessage(msg)) => tracing::info!("{}", msg.message),
// Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name), // Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name),
@ -128,6 +129,14 @@ impl Session {
} }
}); });
let chan = channel.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(20)).await;
chan.send(proto::Packet::Ping(proto::Ping::default())).await.unwrap();
}
});
Ok(s) Ok(s)
} }
} }

View file

@ -1,10 +1,11 @@
use std::net::{SocketAddr, ToSocketAddrs}; use std::net::{SocketAddr, ToSocketAddrs};
use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream}; use tokio::{io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, net::TcpStream, sync::Mutex};
use tokio_native_tls::TlsStream; use tokio_native_tls::TlsStream;
pub struct ControlChannel { pub struct ControlChannel {
stream: TlsStream<TcpStream>, tx: Mutex<WriteHalf<TlsStream<TcpStream>>>,
rx: Mutex<ReadHalf<TlsStream<TcpStream>>>,
} }
impl ControlChannel { impl ControlChannel {
@ -18,24 +19,28 @@ impl ControlChannel {
.danger_accept_invalid_certs(true) .danger_accept_invalid_certs(true)
.build() .build()
.expect("could not create TLS connector").into(); .expect("could not create TLS connector").into();
let stream = connector.connect(host, socket).await.unwrap(); let stream = connector.connect(host, socket).await
Ok(ControlChannel { stream }) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let (rx, tx) = tokio::io::split(stream);
Ok(ControlChannel { tx: Mutex::new(tx), rx: Mutex::new(rx) })
} }
pub async fn send(&mut self, pkt: super::proto::Packet) -> std::io::Result<()> { pub async fn send(&self, pkt: super::proto::Packet) -> std::io::Result<()> {
let (id, buffer) = pkt.encode(); let (id, buffer) = pkt.encode();
self.stream.write_u16(id).await?; let mut tx = self.tx.lock().await;
self.stream.write_u32(buffer.len() as u32).await?; tx.write_u16(id).await?;
self.stream.write_all(&buffer).await?; tx.write_u32(buffer.len() as u32).await?;
self.stream.flush().await?; tx.write_all(&buffer).await?;
tx.flush().await?;
Ok(()) Ok(())
} }
pub async fn recv(&mut self) -> std::io::Result<super::proto::Packet> { pub async fn recv(&self) -> std::io::Result<super::proto::Packet> {
let id = self.stream.read_u16().await?; let mut rx = self.rx.lock().await;
let size = self.stream.read_u32().await?; let id = rx.read_u16().await?;
let size = rx.read_u32().await?;
let mut buffer = vec![0u8; size as usize]; let mut buffer = vec![0u8; size as usize];
self.stream.read_exact(&mut buffer).await?; rx.read_exact(&mut buffer).await?;
super::proto::Packet::decode(id, &buffer) super::proto::Packet::decode(id, &buffer)
} }
} }