(WIP) Data structure refactor

Signed-off-by: Solomon Wagner <solow@solow.xyz>
This commit is contained in:
Solomon W. 2025-02-27 12:16:20 -05:00
parent 00076e2b42
commit ec63debe4f
2 changed files with 225 additions and 63 deletions

View File

@ -1,18 +1,21 @@
[package] [package]
name = "swim-rs" name = "swim"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
authors = ["Solomon W. <solow@solow.xyz>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
anyhow = "1.0.96" anyhow = "1.0.96"
base64 = "0.22.1"
clap = { version = "4.5.30", features = ["derive", "env"] } clap = { version = "4.5.30", features = ["derive", "env"] }
colored = "3.0.0" colored = "3.0.0"
quinn = "0.11.6" quinn = "0.11.6"
serde = { version = "1.0.218", features = ["derive"] } serde = { version = "1.0.218", features = ["derive"] }
serde_json = "1.0.139" serde_json = "1.0.139"
tokio = { version = "1.43.0", features = ["full"] } tokio = { version = "1.43.0", features = ["full"] }
toml = "0.8.20"
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = { version = "0.3.19", features = ["fmt", "env-filter"] } tracing-subscriber = { version = "0.3.19", features = ["fmt", "env-filter"] }

View File

@ -1,19 +1,22 @@
use std::collections::BTreeMap; use anyhow::{Context, Result};
use std::sync::Arc; use base64::Engine;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::fmt::Display;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::{collections::BTreeMap, path::PathBuf};
use tokio::{ use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Result}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::{UnixListener, UnixStream}, net::{UnixListener, UnixStream},
sync::{self, mpsc, RwLock}, sync::{self, broadcast, mpsc, RwLock},
}; };
#[allow(unused_imports)] #[allow(unused_imports)]
use tracing::{error, info, instrument, warn}; use tracing::{error, info, instrument, warn};
static LICENSE: &'static str = include_str!("../LICENSE"); static LICENSE: &str = include_str!("../LICENSE");
const IPC_SOCKET_PATH: &str = "/tmp/swimd.sock";
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
enum SockType { enum SockType {
@ -21,22 +24,69 @@ enum SockType {
Tcp, Tcp,
} }
impl Display for SockType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SockType::Udp => f.write_str("UDP"),
SockType::Tcp => f.write_str("TCP"),
}
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
struct PortMapping { struct PortMapping {
sock_type: SockType, sock_type: SockType,
from: u32, from: u16,
to: u32, to: u16,
}
impl Display for PortMapping {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"{}->{}/{}",
self.from, self.to, self.sock_type
))
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
struct Remote {
host: String,
port: u16,
public_key: Option<[u8; 32]>,
default_tunnel: Tunnel,
tunnels: Vec<Tunnel>,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
struct Tunnel {
mappings: BTreeMap<u32, PortMapping>,
} }
#[derive(Debug, Serialize, Deserialize, Default, Clone)] #[derive(Debug, Serialize, Deserialize, Default, Clone)]
struct Config { struct Config<'a> {
mappings: BTreeMap<u32, PortMapping>, remotes: BTreeMap<String, Remote>,
#[serde(skip)] // (<remote_name> <tunnel idx> <mapping idx>)
flat_mappings: BTreeSet<(String, usize, &'a PortMapping)>,
}
impl<'a> Config<'a> {
fn update_mappings(&'a mut self) {
self.flat_mappings.clear();
for (name, rem) in self.remotes.iter() {
for (tidx, tun) in rem.tunnels.iter().enumerate() {
for mapping in tun.mappings.values() {
self.flat_mappings.insert((name.clone(), tidx, mapping));
}
}
}
}
} }
#[derive(Debug)] #[derive(Debug)]
struct AppState { struct AppState {
active_config: Arc<RwLock<Config>>, active_config: Arc<RwLock<Config<'static>>>,
config: Arc<RwLock<Config>>, config: Arc<RwLock<Config<'static>>>,
notify: sync::mpsc::Sender<()>, // Used to tell the ipc server to reload notify: sync::mpsc::Sender<()>, // Used to tell the ipc server to reload
} }
@ -61,7 +111,7 @@ enum ClientResp<S: AsRef<str>> {
} }
impl<S: AsRef<str> + Serialize> ClientResp<S> { impl<S: AsRef<str> + Serialize> ClientResp<S> {
async fn write_to<W>(&self, mut writer: W) -> Result<()> async fn write_to<W>(&self, mut writer: W) -> std::io::Result<()>
where where
W: AsyncWriteExt + std::marker::Unpin, W: AsyncWriteExt + std::marker::Unpin,
{ {
@ -87,23 +137,24 @@ async fn handle_ipc_client(mut conn: UnixStream, app: Arc<AppState>) {
let resp: ClientResp<_> = match serde_json::from_str::<IPCRequest>(&buf) { let resp: ClientResp<_> = match serde_json::from_str::<IPCRequest>(&buf) {
Ok(req) => match req.op { Ok(req) => match req.op {
ClientOp::Del(ni) => { ClientOp::Del(ni) => {
let cfg = config.write().await;
let km: Vec<String> = ni
.iter()
.filter(|i| cfg.mappings.contains_key(i))
.map(|n| n.to_string())
.collect();
let kml = km.join(",");
if !km.len() == 0 {
error!("Bad rule ids when deleting: {}", kml);
ClientResp::Error(format!("Some route IDs do not exist: {}", kml)) // let cfg = config.write().await;
} else { // let km: Vec<String> = ni
for i in ni { // .iter()
info!("Removed ruleset {i}!"); // .filter(|i| cfg.mappings.contains_key(i))
} // .map(|n| n.to_string())
ClientResp::Ok(Some(format!("Removed routes: {kml}"))) // .collect();
} // let kml = km.join(",");
// if !km.len() == 0 {
// error!("Bad rule ids when deleting: {}", kml);
//
// ClientResp::Error(format!("Some route IDs do not exist: {}", kml))
// } else {
// for i in ni {
// info!("Removed ruleset {i}!");
// }
// ClientResp::Ok(Some(format!("Removed routes: {kml}")))
// }
} }
ClientOp::Export => { ClientOp::Export => {
let cfg = serde_json::to_string(&*config.read().await).unwrap(); let cfg = serde_json::to_string(&*config.read().await).unwrap();
@ -150,11 +201,23 @@ async fn ipc_server(app: Arc<AppState>, sock: UnixListener) {
} }
} }
fn spawn_listeners(cfg: &Config) {
for (id, mapping) in &cfg.mappings {
info!("Starting port mapping {id} : {mapping}");
let handle = tokio::spawn(listener_run);
}
}
async fn quic_server(app: Arc<AppState>, mut notify: mpsc::Receiver<()>) { async fn quic_server(app: Arc<AppState>, mut notify: mpsc::Receiver<()>) {
let mut handles = Vec::new();
let (tx, rx) = broadcast::channel::<()>(5);
while notify.recv().await.is_some() { while notify.recv().await.is_some() {
let mut active = app.active_config.write().await; let mut active = app.active_config.write().await;
let cfg = app.config.read().await.clone(); let cfg = app.config.read().await.clone();
*active = cfg; *active = cfg;
let a = app.active_config.read().await;
tx.send(());
info!("Reloaded config!"); info!("Reloaded config!");
} }
} }
@ -179,35 +242,136 @@ enum Connection {
Info, Info,
} }
#[derive(Debug, clap::Args)]
struct ServerArgs {
#[arg(short, long, help = "Causes the server to fork from the shell")]
daemonize: bool,
#[arg(short, long, help = "Change the control socket path")]
socket_path: Option<String>,
}
#[derive(Debug, Clone, PartialEq)]
struct PubKey([u8; 32]);
impl FromStr for PubKey {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
if s.len() != 43 {
return Err(format!(
"Public key must be exactly 43 characters, got {}",
s.len()
));
}
let decoded = base64::engine::general_purpose::STANDARD
.decode(s)
.map_err(|e| format!("Invalid base64 string: {}", e))?;
if decoded.len() != 32 {
return Err(format!(
"Decoded key must have a length of 32, got {}",
decoded.len()
));
}
let mut key = [0u8; 32];
key.copy_from_slice(&decoded);
Ok(PubKey(key))
}
}
impl Serialize for PubKey {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&base64::engine::general_purpose::STANDARD.encode(self.0))
}
}
impl<'de> Deserialize<'de> for PubKey {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
PubKey::from_str(&String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
}
}
#[derive(Debug, clap::Args)]
struct RemoteAddArgs {
#[arg(help = "Host ip/domain of a server")]
host: String,
#[arg(short, long, help = "Chatter port for the server")]
port: u8,
#[arg(short, long, help = "Pubkey for the remote (skips key exchange)")]
pubkey: PubKey,
}
#[derive(Subcommand, Debug)]
#[command(
about = "Operations related to remotes (peers)",
long_about = "Tunnels are created between remotes"
)]
enum RemoteCmd {
Add(RemoteAddArgs),
List,
Rm,
Info,
}
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
enum Commands { enum Commands {
#[command(subcommand)] #[command(subcommand)]
Connection(Connection), Connection(Connection),
Remote, #[command(subcommand)]
Remote(RemoteCmd),
License,
Version,
Server(ServerArgs),
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author = "Solomon W.", version, about = "The tunneling firewall")] #[command(about = "The tunneling firewall")]
struct Cli { struct Cli {
#[arg(long, help = "Run as the swim daemon")]
daemonize: bool,
#[arg(long, help = "Display the license")]
license: bool,
#[command(subcommand)] #[command(subcommand)]
command: Option<Commands>, command: Commands,
}
static NAME: &str = env!("CARGO_PKG_NAME");
static VERSION: &str = env!("CARGO_PKG_VERSION");
static AUTHOR: &str = env!("CARGO_PKG_AUTHORS");
fn get_socket_path() -> PathBuf {
if std::env::var("XDG_RUNTIME_DIR").is_ok() {
std::env::var("XDG_RUNTIME_DIR")
.map(|dir| PathBuf::from(dir).join("swim.sock"))
.unwrap()
} else if std::env::var("SWIM_SYSTEM_DAEMON").is_ok() {
PathBuf::from("/run/swim.sock")
} else {
// Fallback to /tmp
PathBuf::from("/tmp/swim.sock")
}
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let cli = Cli::parse(); let cli = Cli::parse();
if cli.license { match cli.command {
println!("{}", LICENSE); Commands::Connection(_connection) => println!("Hello swim!"),
std::process::exit(0); Commands::License => println!("author(s): {}\n\n{}", AUTHOR, LICENSE),
} Commands::Version => println!("{} {}", NAME, VERSION),
Commands::Server(cmd) => {
let sock_path = cmd
.socket_path
.map(PathBuf::from)
.or_else(|| Some(get_socket_path()))
.unwrap();
if cli.daemonize { let sock = UnixListener::bind(sock_path).expect("Couldn't bind to unix socket!");
let sock = UnixListener::bind(IPC_SOCKET_PATH).expect("Couldn't bind to unix socket!");
let (tx, rx) = mpsc::channel(10); let (tx, rx) = mpsc::channel(10);
let config = Config::default(); let config = Config::default();
@ -226,12 +390,7 @@ async fn main() -> Result<()> {
ipc_handle.await.unwrap(); ipc_handle.await.unwrap();
quic_handle.await.unwrap(); quic_handle.await.unwrap();
} }
if let Some(cmd) = cli.command {
match cmd {
Commands::Connection(_connection) => println!("Hello swim!"),
_ => todo!("Not yet implemented!"), _ => todo!("Not yet implemented!"),
} }
}
Ok(()) Ok(())
} }