diff --git a/src/main.rs b/src/main.rs index b573b43..e2c0968 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,12 @@ -use std::{sync::Arc, net::SocketAddr}; +use std::net::SocketAddr; use serde::Deserialize; -use tokio::sync::Mutex; - use clap::Parser; use axum::{routing::get, extract::Query, Json, Router}; -use crate::{model::{MapResult, Map}, cache::CACHE}; +use crate::{model::{MapResult, MapHandle, create_map_collector}, cache::CACHE}; mod model; mod cache; @@ -48,23 +46,15 @@ struct Params { async fn route_crawl_domain(Query(params): Query) -> Json { tracing::info!("starting new crawl from {}", params.domain); - - let map = Arc::new(Mutex::new(Map::default())); - - scan_instance(¶ms.domain, map.clone()).await; - - let _map = map.lock().await; - - axum::Json(MapResult { - nodes: _map.get_nodes().clone(), - vertices: _map.get_vertices().clone(), - }) + let (collector, handle) = create_map_collector(); + scan_instance(¶ms.domain, handle).await; + axum::Json(collector.collect().await) } #[async_recursion::async_recursion] -async fn scan_instance(domain: &str, map: Arc>) -> Option<()> { - if map.lock().await.already_scanned(domain) { return None }; +async fn scan_instance(domain: &str, map: MapHandle) -> Option<()> { + if map.already_scanned(domain).await { return None }; tracing::debug!("scanning instance {}", domain); let response = match CACHE.instance_metadata(domain).await { @@ -87,7 +77,7 @@ async fn scan_instance(domain: &str, map: Arc>) -> Option<()> { tracing::info!("adding instance {}", node_name); - map.lock().await.add_node(domain); + map.add_node(domain.to_string(), node_name); let mut tasks = Vec::new(); @@ -99,13 +89,9 @@ async fn scan_instance(domain: &str, map: Arc>) -> Option<()> { .filter_map(|x| x.as_str().map(|x| x.to_string())) { let _map = map.clone(); - map.lock().await.add_vertex(domain, &bubble_instance); + map.add_vertex(domain.to_string(), bubble_instance.clone()); tasks.push(tokio::spawn(async move { scan_instance(&bubble_instance, _map).await; })); } - for t in tasks { - t.await.expect("could not join task"); - } - Some(()) } diff --git a/src/model.rs b/src/model.rs index cbf44a8..954dadc 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,55 +1,114 @@ -use std::collections::{HashMap, HashSet}; +use std::{collections::{HashSet, HashMap}, sync::Arc}; use serde::Serialize; +use tokio::sync::{mpsc, RwLock}; -#[derive(Default)] -pub struct Map { - name_to_id: HashMap, - scanned: HashSet, - nodes: Vec, - vertices: Vec, - counter: usize, +pub struct MapCollector { + nodes_rx: mpsc::UnboundedReceiver, + vertices_rx: mpsc::UnboundedReceiver, } -impl Map { - pub fn already_scanned(&mut self, domain: &str) -> bool { - let out = self.scanned.contains(domain); - self.scanned.insert(domain.to_string()); - out - } +pub fn create_map_collector() -> (MapCollector, MapHandle) { + let (nodes_tx, nodes_rx) = mpsc::unbounded_channel(); + let (vertices_tx, vertices_rx) = mpsc::unbounded_channel(); + let scanned = Arc::new(RwLock::new(HashSet::new())); + ( + MapCollector { nodes_rx, vertices_rx }, + MapHandle { nodes_tx, vertices_tx, scanned }, + ) +} - pub fn get_nodes(&self) -> &Vec { - &self.nodes - } - - pub fn get_vertices(&self) -> &Vec { - &self.vertices - } - - pub fn add_node(&mut self, domain: &str) -> usize { - match self.name_to_id.get(domain) { - Some(id) => *id, - None => { - let id = self.counter; - self.name_to_id.insert(domain.to_string(), id); - self.nodes.push(Node { label: domain.to_string(), id }); - self.counter += 1; - id +impl MapCollector { + pub async fn collect(mut self) -> MapResult { + let mut nodes_domains = Vec::new(); + let mut vertices_domains = Vec::new(); + loop { + tokio::select! { + Some(node) = self.nodes_rx.recv() => nodes_domains.push(node), + Some(vertex) = self.vertices_rx.recv() => vertices_domains.push(vertex), + else => break, } } + + tracing::info!("received all nodes and vertices, processing"); + let mut nodes_map : HashMap = HashMap::new(); + let mut nodes = Vec::new(); + let mut vertices = Vec::new(); + + for (i, node) in nodes_domains.iter().enumerate() { + nodes_map.insert( + node.domain.clone(), + Node { id: i, label: node.domain.clone(), title: node.name.clone(), value: 1 } + ); + } + + for vertex in vertices_domains { + let from = { + let node = nodes_map.get_mut(&vertex.from).expect("vertex from non existing node"); + node.value += 1; + node.id + }; + + let to = { + let node = nodes_map.get_mut(&vertex.to).expect("vertex to non existing node"); + node.value += 5; + node.id + }; + + vertices.push(Vertex { from, to }); + } + + for (_, node) in nodes_map { + nodes.push(node); + } + + MapResult { nodes, vertices } + } +} + +#[derive(Clone)] +pub struct MapHandle { + scanned: Arc>>, + nodes_tx: mpsc::UnboundedSender, + vertices_tx: mpsc::UnboundedSender, +} + +impl MapHandle { + pub async fn already_scanned(&self, domain: &str) -> bool { + let present = self.scanned.read().await.contains(domain); + if !present { self.scanned.write().await.insert(domain.to_string()); } + present } - pub fn add_vertex(&mut self, from_domain: &str, to_domain: &str) { - let from = self.add_node(from_domain); - let to = self.add_node(to_domain); - self.vertices.push(Vertex { from, to }); + pub fn add_node(&self, domain: String, name: String) { + self.nodes_tx.send(NodeDomain { domain, name }) + .expect("could not send node to collector") } + + pub fn add_vertex(&self, from: String, to: String) { + self.vertices_tx.send(VertexDomain { from, to }) + .expect("could not send vertex to collector") + } +} + +#[derive(Clone, Debug)] +pub struct NodeDomain { + pub domain: String, + pub name: String, +} + +#[derive(Clone, Debug)] +pub struct VertexDomain { + pub from: String, + pub to: String, } #[derive(Serialize, Clone, Debug)] pub struct Node { pub id: usize, pub label: String, + pub value: usize, + pub title: String, } #[derive(Serialize, Clone, Debug)] @@ -58,6 +117,7 @@ pub struct Vertex { pub to: usize, } + #[derive(Serialize, Clone, Debug)] pub struct MapResult { pub nodes: Vec,