diff --git a/Cargo.toml b/Cargo.toml index 5a1da37..027681c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,9 @@ edition = "2021" [dependencies] async-recursion = "1.0.5" axum = "0.6.20" +chrono = "0.4.31" clap = { version = "4.4.6", features = ["derive"] } +lazy_static = "1.4.0" reqwest = { version = "0.11.20", features = ["json"] } serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.107" diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..31960d5 --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,40 @@ +use std::{sync::Arc, collections::HashMap, time::Duration}; + +use chrono::Utc; +use tokio::sync::Mutex; + +lazy_static::lazy_static! { + pub static ref CACHE : Arc<Mutex<InstanceCache>> = Arc::new(Mutex::new(InstanceCache::default())); +} + +const MAX_CACHE_AGE : i64 = 86400; + +#[derive(Default)] +pub struct InstanceCache { + store: HashMap<String, (i64, serde_json::Value)>, +} + +impl InstanceCache { + pub async fn instance_metadata(&mut self, domain: &str) -> reqwest::Result<serde_json::Value> { + let now = Utc::now().timestamp(); + + if let Some((age, value)) = self.store.get(domain) { + if now - age < MAX_CACHE_AGE { + return Ok(value.clone()); + } + } + + let value = reqwest::Client::builder() + .timeout(Duration::from_secs(5)) + .build()? + .get(format!("https://{}/nodeinfo/2.0.json", domain)) + .send() + .await?
+ .json::<serde_json::Value>() + .await?; + + self.store.insert(domain.to_string(), (now, value.clone())); + + Ok(value) + } +} diff --git a/src/main.rs b/src/main.rs index 958d7fe..82b09f2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ -use std::{sync::Arc, collections::{HashMap, HashSet}, time::Duration, net::SocketAddr}; +use std::{sync::Arc, net::SocketAddr}; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use tokio::sync::Mutex; @@ -8,11 +8,37 @@ use clap::Parser; use axum::{routing::get, extract::Query, Json, Router}; +use crate::{model::{MapResult, Map}, cache::CACHE}; + +mod model; +mod cache; + #[derive(Debug, Parser)] /// an API crawling akkoma bubble instances network and creating a map struct CliArgs { - /// start domain for crawl, without proto base - domain: String, + /// start the server listening on this host + host: Option<String>, +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + let args = CliArgs::parse(); + + let app = Router::new() + .route("/crawl", get(route_crawl_domain)); + + let addr = match args.host { + Some(host) => host.parse().expect("could not parse provided host"), + None => SocketAddr::from(([127, 0, 0, 1], 18811)), + }; + + tracing::debug!("listening on {}", addr); + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .expect("could not serve axum app"); } #[derive(Debug, Deserialize)] @@ -20,105 +46,28 @@ struct Params { domain: String } - -#[tokio::main] -async fn main() { - // initialize tracing - tracing_subscriber::fmt::init(); - - // build our application with a route - let app = Router::new() - .route("/crawl", get(route_scan_domain)); - - // run our app with hyper, listening globally on port 3000 - let addr = SocketAddr::from(([127, 0, 0, 1], 18811)); - tracing::debug!("listening on {}", addr); - axum::Server::bind(&addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} - -async fn route_scan_domain(Query(params): Query<Params>) -> Json<MapResult> { +async fn
route_crawl_domain(Query(params): Query<Params>) -> Json<MapResult> { tracing::info!("starting new crawl from {}", params.domain); - let map = Arc::new(Mutex::new(Map { - scanned: HashSet::new(), - name_to_id: HashMap::new(), - counter: 0, - nodes: Vec::new(), - vertices: Vec::new(), - })); + let map = Arc::new(Mutex::new(Map::default())); scan_instance(&params.domain, map.clone()).await; let _map = map.lock().await; axum::Json(MapResult { - nodes: _map.nodes.clone(), - vertices: _map.vertices.clone(), + nodes: _map.get_nodes().clone(), + vertices: _map.get_vertices().clone(), }) } -struct Map { - name_to_id: HashMap<String, usize>, - scanned: HashSet<String>, - nodes: Vec<Node>, - vertices: Vec<Vertex>, - counter: usize, -} - -impl Map { - fn scanned(&mut self, name: &str) -> bool { - let out = self.scanned.contains(name); - self.scanned.insert(name.to_string()); - out - } - - fn node(&mut self, name: String) -> usize { - match self.name_to_id.get(&name) { - Some(id) => *id, - None => { - let id = self.counter; - self.name_to_id.insert(name.clone(), id); - self.nodes.push(Node { label: name, id }); - self.counter += 1; - id - } - } - } - - fn vertex(&mut self, from_name: String, to_name: String) { - let from = self.node(from_name); - let to = self.node(to_name); - self.vertices.push(Vertex { from, to }); - } -} - -#[derive(Serialize, Clone, Debug)] -struct Node { - id: usize, - label: String, -} - -#[derive(Serialize, Clone, Debug)] -struct Vertex { - from: usize, - to: usize, -} - -#[derive(Serialize, Clone, Debug)] -struct MapResult { - nodes: Vec<Node>, - vertices: Vec<Vertex>, -} #[async_recursion::async_recursion] async fn scan_instance(domain: &str, map: Arc<Mutex<Map>>) -> Option<()> { - if map.lock().await.scanned(domain) { return None }; + if map.lock().await.already_scanned(domain) { return None }; tracing::debug!("scanning instance {}", domain); - let response = match instance_metadata(domain).await { + let response = match CACHE.lock().await.instance_metadata(domain).await { Ok(r) => r, Err(e) => { tracing::warn!("could not fetch
metadata for {}: {}", domain, e); @@ -134,7 +83,7 @@ async fn scan_instance(domain: &str, map: Arc<Mutex<Map>>) -> Option<()> { tracing::info!("adding instance {}", node_name); - map.lock().await.node(domain.to_string()); + map.lock().await.add_node(domain); let mut tasks = Vec::new(); @@ -146,7 +95,7 @@ async fn scan_instance(domain: &str, map: Arc<Mutex<Map>>) -> Option<()> { .filter_map(|x| x.as_str().map(|x| x.to_string())) { let _map = map.clone(); - map.lock().await.vertex(domain.to_string(), bubble_instance.clone()); + map.lock().await.add_vertex(domain, &bubble_instance); tasks.push(tokio::spawn(async move { scan_instance(&bubble_instance, _map).await; })); } @@ -156,15 +105,3 @@ async fn scan_instance(domain: &str, map: Arc<Mutex<Map>>) -> Option<()> { Some(()) } - -async fn instance_metadata(domain: &str) -> reqwest::Result<serde_json::Value> { - reqwest::Client::builder() - .timeout(Duration::from_secs(5)) - .build()? - .get(format!("https://{}/nodeinfo/2.0.json", domain)) - .send() - .await? - .json::<serde_json::Value>() - .await -} - diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..cbf44a8 --- /dev/null +++ b/src/model.rs @@ -0,0 +1,65 @@ +use std::collections::{HashMap, HashSet}; + +use serde::Serialize; + +#[derive(Default)] +pub struct Map { + name_to_id: HashMap<String, usize>, + scanned: HashSet<String>, + nodes: Vec<Node>, + vertices: Vec<Vertex>, + counter: usize, +} + +impl Map { + pub fn already_scanned(&mut self, domain: &str) -> bool { + let out = self.scanned.contains(domain); + self.scanned.insert(domain.to_string()); + out + } + + pub fn get_nodes(&self) -> &Vec<Node> { + &self.nodes + } + + pub fn get_vertices(&self) -> &Vec<Vertex> { + &self.vertices + } + + pub fn add_node(&mut self, domain: &str) -> usize { + match self.name_to_id.get(domain) { + Some(id) => *id, + None => { + let id = self.counter; + self.name_to_id.insert(domain.to_string(), id); + self.nodes.push(Node { label: domain.to_string(), id }); + self.counter += 1; + id + } + } + } + + pub fn add_vertex(&mut self, from_domain: &str, to_domain: &str) { + let from =
self.add_node(from_domain); + let to = self.add_node(to_domain); + self.vertices.push(Vertex { from, to }); + } +} + +#[derive(Serialize, Clone, Debug)] +pub struct Node { + pub id: usize, + pub label: String, +} + +#[derive(Serialize, Clone, Debug)] +pub struct Vertex { + pub from: usize, + pub to: usize, +} + +#[derive(Serialize, Clone, Debug)] +pub struct MapResult { + pub nodes: Vec<Node>, + pub vertices: Vec<Vertex>, +}