From c99decdd240e76c3734c933a3276dc349cd408a4 Mon Sep 17 00:00:00 2001 From: alemi Date: Wed, 21 Feb 2024 18:32:08 +0100 Subject: [PATCH] feat: split control tx/rx, send pings from task --- src/session.rs | 13 +++++++++++-- src/tcp/control.rs | 31 ++++++++++++++++++------------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/session.rs b/src/session.rs index e05d869..2d00b3d 100644 --- a/src/session.rs +++ b/src/session.rs @@ -70,7 +70,7 @@ impl Session { pub async fn connect(host: &str, port: Option, username: Option, password: Option) -> std::io::Result> { let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string()); - let mut channel = ControlChannel::new(host, port).await?; + let channel = Arc::new(ControlChannel::new(host, port).await?); let version = proto::Version { version_v1: None, version_v2: Some(281496485429248), @@ -105,6 +105,7 @@ impl Session { }); let session = s.clone(); + let chan = channel.clone(); tokio::spawn(async move { loop { match rx.try_recv() { @@ -112,7 +113,7 @@ impl Session { Err(mpsc::error::TryRecvError::Empty) => {}, Err(mpsc::error::TryRecvError::Disconnected) => break tracing::warn!("all session dropped without stopping this task, stopping..."), } - match channel.recv().await { + match chan.recv().await { Err(e) => break tracing::warn!("disconnected from server: {}", e), // Ok(tcp::proto::Packet::TextMessage(msg)) => tracing::info!("{}", msg.message), // Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name), @@ -128,6 +129,14 @@ impl Session { } }); + let chan = channel.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(20)).await; + chan.send(proto::Packet::Ping(proto::Ping::default())).await.unwrap(); + } + }); + Ok(s) } } diff --git a/src/tcp/control.rs b/src/tcp/control.rs index 16e8af2..12d5bd1 100644 --- a/src/tcp/control.rs +++ b/src/tcp/control.rs @@ -1,10 +1,11 @@ use std::net::{SocketAddr, ToSocketAddrs}; -use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream}; +use tokio::{io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, net::TcpStream, sync::Mutex}; use tokio_native_tls::TlsStream; pub struct ControlChannel { - stream: TlsStream, + tx: Mutex>>, + rx: Mutex>>, } impl ControlChannel { @@ -18,24 +19,28 @@ impl ControlChannel { .danger_accept_invalid_certs(true) .build() .expect("could not create TLS connector").into(); - let stream = connector.connect(host, socket).await.unwrap(); - Ok(ControlChannel { stream }) + let stream = connector.connect(host, socket).await + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let (rx, tx) = tokio::io::split(stream); + Ok(ControlChannel { tx: Mutex::new(tx), rx: Mutex::new(rx) }) } - pub async fn send(&mut self, pkt: super::proto::Packet) -> std::io::Result<()> { + pub async fn send(&self, pkt: super::proto::Packet) -> std::io::Result<()> { let (id, buffer) = pkt.encode(); - self.stream.write_u16(id).await?; - self.stream.write_u32(buffer.len() as u32).await?; - self.stream.write_all(&buffer).await?; - self.stream.flush().await?; + let mut tx = self.tx.lock().await; + tx.write_u16(id).await?; + tx.write_u32(buffer.len() as u32).await?; + tx.write_all(&buffer).await?; + tx.flush().await?; Ok(()) } - pub async fn recv(&mut self) -> std::io::Result { - let id = self.stream.read_u16().await?; - let size = self.stream.read_u32().await?; + pub async fn recv(&self) -> std::io::Result { + let mut rx = self.rx.lock().await; + let id = rx.read_u16().await?; + let size = rx.read_u32().await?; let mut buffer = vec![0u8; size as usize]; - self.stream.read_exact(&mut buffer).await?; + rx.read_exact(&mut buffer).await?; super::proto::Packet::decode(id, &buffer) } }