feat: split control tx/rx, send pings from task

This commit is contained in:
əlemi 2024-02-21 18:32:08 +01:00
parent fc1da3d88d
commit c99decdd24
Signed by: alemi
GPG key ID: A4895B84D311642C
2 changed files with 29 additions and 15 deletions

View file

@ -70,7 +70,7 @@ impl Session {
pub async fn connect(host: &str, port: Option<u16>, username: Option<String>, password: Option<String>) -> std::io::Result<Arc<Self>> {
let username = username.unwrap_or_else(|| ".mumble-stats-api".to_string());
let mut channel = ControlChannel::new(host, port).await?;
let channel = Arc::new(ControlChannel::new(host, port).await?);
let version = proto::Version {
version_v1: None,
version_v2: Some(281496485429248),
@ -105,6 +105,7 @@ impl Session {
});
let session = s.clone();
let chan = channel.clone();
tokio::spawn(async move {
loop {
match rx.try_recv() {
@ -112,7 +113,7 @@ impl Session {
Err(mpsc::error::TryRecvError::Empty) => {},
Err(mpsc::error::TryRecvError::Disconnected) => break tracing::warn!("all session dropped without stopping this task, stopping..."),
}
match channel.recv().await {
match chan.recv().await {
Err(e) => break tracing::warn!("disconnected from server: {}", e),
// Ok(tcp::proto::Packet::TextMessage(msg)) => tracing::info!("{}", msg.message),
// Ok(tcp::proto::Packet::ChannelState(channel)) => tracing::info!("discovered channel: {:?}", channel.name),
@ -128,6 +129,14 @@ impl Session {
}
});
let chan = channel.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(20)).await;
chan.send(proto::Packet::Ping(proto::Ping::default())).await.unwrap();
}
});
Ok(s)
}
}

View file

@ -1,10 +1,11 @@
use std::net::{SocketAddr, ToSocketAddrs};
use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream};
use tokio::{io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, net::TcpStream, sync::Mutex};
use tokio_native_tls::TlsStream;
pub struct ControlChannel {
stream: TlsStream<TcpStream>,
tx: Mutex<WriteHalf<TlsStream<TcpStream>>>,
rx: Mutex<ReadHalf<TlsStream<TcpStream>>>,
}
impl ControlChannel {
@ -18,24 +19,28 @@ impl ControlChannel {
.danger_accept_invalid_certs(true)
.build()
.expect("could not create TLS connector").into();
let stream = connector.connect(host, socket).await.unwrap();
Ok(ControlChannel { stream })
let stream = connector.connect(host, socket).await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let (rx, tx) = tokio::io::split(stream);
Ok(ControlChannel { tx: Mutex::new(tx), rx: Mutex::new(rx) })
}
pub async fn send(&mut self, pkt: super::proto::Packet) -> std::io::Result<()> {
pub async fn send(&self, pkt: super::proto::Packet) -> std::io::Result<()> {
let (id, buffer) = pkt.encode();
self.stream.write_u16(id).await?;
self.stream.write_u32(buffer.len() as u32).await?;
self.stream.write_all(&buffer).await?;
self.stream.flush().await?;
let mut tx = self.tx.lock().await;
tx.write_u16(id).await?;
tx.write_u32(buffer.len() as u32).await?;
tx.write_all(&buffer).await?;
tx.flush().await?;
Ok(())
}
pub async fn recv(&mut self) -> std::io::Result<super::proto::Packet> {
let id = self.stream.read_u16().await?;
let size = self.stream.read_u32().await?;
pub async fn recv(&self) -> std::io::Result<super::proto::Packet> {
let mut rx = self.rx.lock().await;
let id = rx.read_u16().await?;
let size = rx.read_u32().await?;
let mut buffer = vec![0u8; size as usize];
self.stream.read_exact(&mut buffer).await?;
rx.read_exact(&mut buffer).await?;
super::proto::Packet::decode(id, &buffer)
}
}