(WIP) Data structure refactor
Signed-off-by: Solomon Wagner <solow@solow.xyz>
This commit is contained in:
parent
00076e2b42
commit
ec63debe4f
@ -1,18 +1,21 @@
|
||||
[package]
|
||||
name = "swim-rs"
|
||||
name = "swim"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["Solomon W. <solow@solow.xyz>"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.96"
|
||||
base64 = "0.22.1"
|
||||
clap = { version = "4.5.30", features = ["derive", "env"] }
|
||||
colored = "3.0.0"
|
||||
quinn = "0.11.6"
|
||||
serde = { version = "1.0.218", features = ["derive"] }
|
||||
serde_json = "1.0.139"
|
||||
tokio = { version = "1.43.0", features = ["full"] }
|
||||
toml = "0.8.20"
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.19", features = ["fmt", "env-filter"] }
|
||||
|
||||
|
283
src/main.rs
283
src/main.rs
@ -1,19 +1,22 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use base64::Engine;
|
||||
use clap::{Parser, Subcommand};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeSet;
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::{collections::BTreeMap, path::PathBuf};
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Result},
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||
net::{UnixListener, UnixStream},
|
||||
sync::{self, mpsc, RwLock},
|
||||
sync::{self, broadcast, mpsc, RwLock},
|
||||
};
|
||||
#[allow(unused_imports)]
|
||||
use tracing::{error, info, instrument, warn};
|
||||
|
||||
static LICENSE: &'static str = include_str!("../LICENSE");
|
||||
|
||||
const IPC_SOCKET_PATH: &str = "/tmp/swimd.sock";
|
||||
static LICENSE: &str = include_str!("../LICENSE");
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
enum SockType {
|
||||
@ -21,22 +24,69 @@ enum SockType {
|
||||
Tcp,
|
||||
}
|
||||
|
||||
impl Display for SockType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SockType::Udp => f.write_str("UDP"),
|
||||
SockType::Tcp => f.write_str("TCP"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
struct PortMapping {
|
||||
sock_type: SockType,
|
||||
from: u32,
|
||||
to: u32,
|
||||
from: u16,
|
||||
to: u16,
|
||||
}
|
||||
|
||||
impl Display for PortMapping {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"{}->{}/{}",
|
||||
self.from, self.to, self.sock_type
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
struct Remote {
|
||||
host: String,
|
||||
port: u16,
|
||||
public_key: Option<[u8; 32]>,
|
||||
default_tunnel: Tunnel,
|
||||
tunnels: Vec<Tunnel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
struct Tunnel {
|
||||
mappings: BTreeMap<u32, PortMapping>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
|
||||
struct Config {
|
||||
mappings: BTreeMap<u32, PortMapping>,
|
||||
struct Config<'a> {
|
||||
remotes: BTreeMap<String, Remote>,
|
||||
#[serde(skip)] // (<remote_name> <tunnel idx> <mapping idx>)
|
||||
flat_mappings: BTreeSet<(String, usize, &'a PortMapping)>,
|
||||
}
|
||||
|
||||
impl<'a> Config<'a> {
|
||||
fn update_mappings(&'a mut self) {
|
||||
self.flat_mappings.clear();
|
||||
for (name, rem) in self.remotes.iter() {
|
||||
for (tidx, tun) in rem.tunnels.iter().enumerate() {
|
||||
for mapping in tun.mappings.values() {
|
||||
self.flat_mappings.insert((name.clone(), tidx, mapping));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AppState {
|
||||
active_config: Arc<RwLock<Config>>,
|
||||
config: Arc<RwLock<Config>>,
|
||||
active_config: Arc<RwLock<Config<'static>>>,
|
||||
config: Arc<RwLock<Config<'static>>>,
|
||||
notify: sync::mpsc::Sender<()>, // Used to tell the ipc server to reload
|
||||
}
|
||||
|
||||
@ -61,7 +111,7 @@ enum ClientResp<S: AsRef<str>> {
|
||||
}
|
||||
|
||||
impl<S: AsRef<str> + Serialize> ClientResp<S> {
|
||||
async fn write_to<W>(&self, mut writer: W) -> Result<()>
|
||||
async fn write_to<W>(&self, mut writer: W) -> std::io::Result<()>
|
||||
where
|
||||
W: AsyncWriteExt + std::marker::Unpin,
|
||||
{
|
||||
@ -87,23 +137,24 @@ async fn handle_ipc_client(mut conn: UnixStream, app: Arc<AppState>) {
|
||||
let resp: ClientResp<_> = match serde_json::from_str::<IPCRequest>(&buf) {
|
||||
Ok(req) => match req.op {
|
||||
ClientOp::Del(ni) => {
|
||||
let cfg = config.write().await;
|
||||
let km: Vec<String> = ni
|
||||
.iter()
|
||||
.filter(|i| cfg.mappings.contains_key(i))
|
||||
.map(|n| n.to_string())
|
||||
.collect();
|
||||
let kml = km.join(",");
|
||||
if !km.len() == 0 {
|
||||
error!("Bad rule ids when deleting: {}", kml);
|
||||
|
||||
ClientResp::Error(format!("Some route IDs do not exist: {}", kml))
|
||||
} else {
|
||||
for i in ni {
|
||||
info!("Removed ruleset {i}!");
|
||||
}
|
||||
ClientResp::Ok(Some(format!("Removed routes: {kml}")))
|
||||
}
|
||||
// let cfg = config.write().await;
|
||||
// let km: Vec<String> = ni
|
||||
// .iter()
|
||||
// .filter(|i| cfg.mappings.contains_key(i))
|
||||
// .map(|n| n.to_string())
|
||||
// .collect();
|
||||
// let kml = km.join(",");
|
||||
// if !km.len() == 0 {
|
||||
// error!("Bad rule ids when deleting: {}", kml);
|
||||
//
|
||||
// ClientResp::Error(format!("Some route IDs do not exist: {}", kml))
|
||||
// } else {
|
||||
// for i in ni {
|
||||
// info!("Removed ruleset {i}!");
|
||||
// }
|
||||
// ClientResp::Ok(Some(format!("Removed routes: {kml}")))
|
||||
// }
|
||||
}
|
||||
ClientOp::Export => {
|
||||
let cfg = serde_json::to_string(&*config.read().await).unwrap();
|
||||
@ -150,11 +201,23 @@ async fn ipc_server(app: Arc<AppState>, sock: UnixListener) {
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_listeners(cfg: &Config) {
|
||||
for (id, mapping) in &cfg.mappings {
|
||||
info!("Starting port mapping {id} : {mapping}");
|
||||
let handle = tokio::spawn(listener_run);
|
||||
}
|
||||
}
|
||||
|
||||
async fn quic_server(app: Arc<AppState>, mut notify: mpsc::Receiver<()>) {
|
||||
let mut handles = Vec::new();
|
||||
let (tx, rx) = broadcast::channel::<()>(5);
|
||||
|
||||
while notify.recv().await.is_some() {
|
||||
let mut active = app.active_config.write().await;
|
||||
let cfg = app.config.read().await.clone();
|
||||
*active = cfg;
|
||||
let a = app.active_config.read().await;
|
||||
tx.send(());
|
||||
info!("Reloaded config!");
|
||||
}
|
||||
}
|
||||
@ -179,59 +242,155 @@ enum Connection {
|
||||
Info,
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
struct ServerArgs {
|
||||
#[arg(short, long, help = "Causes the server to fork from the shell")]
|
||||
daemonize: bool,
|
||||
#[arg(short, long, help = "Change the control socket path")]
|
||||
socket_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
struct PubKey([u8; 32]);
|
||||
|
||||
impl FromStr for PubKey {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
if s.len() != 43 {
|
||||
return Err(format!(
|
||||
"Public key must be exactly 43 characters, got {}",
|
||||
s.len()
|
||||
));
|
||||
}
|
||||
|
||||
let decoded = base64::engine::general_purpose::STANDARD
|
||||
.decode(s)
|
||||
.map_err(|e| format!("Invalid base64 string: {}", e))?;
|
||||
|
||||
if decoded.len() != 32 {
|
||||
return Err(format!(
|
||||
"Decoded key must have a length of 32, got {}",
|
||||
decoded.len()
|
||||
));
|
||||
}
|
||||
|
||||
let mut key = [0u8; 32];
|
||||
key.copy_from_slice(&decoded);
|
||||
Ok(PubKey(key))
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for PubKey {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&base64::engine::general_purpose::STANDARD.encode(self.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for PubKey {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
PubKey::from_str(&String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, clap::Args)]
|
||||
struct RemoteAddArgs {
|
||||
#[arg(help = "Host ip/domain of a server")]
|
||||
host: String,
|
||||
#[arg(short, long, help = "Chatter port for the server")]
|
||||
port: u8,
|
||||
#[arg(short, long, help = "Pubkey for the remote (skips key exchange)")]
|
||||
pubkey: PubKey,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
#[command(
|
||||
about = "Operations related to remotes (peers)",
|
||||
long_about = "Tunnels are created between remotes"
|
||||
)]
|
||||
enum RemoteCmd {
|
||||
Add(RemoteAddArgs),
|
||||
List,
|
||||
Rm,
|
||||
Info,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Commands {
|
||||
#[command(subcommand)]
|
||||
Connection(Connection),
|
||||
Remote,
|
||||
#[command(subcommand)]
|
||||
Remote(RemoteCmd),
|
||||
License,
|
||||
Version,
|
||||
Server(ServerArgs),
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author = "Solomon W.", version, about = "The tunneling firewall")]
|
||||
#[command(about = "The tunneling firewall")]
|
||||
struct Cli {
|
||||
#[arg(long, help = "Run as the swim daemon")]
|
||||
daemonize: bool,
|
||||
#[arg(long, help = "Display the license")]
|
||||
license: bool,
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
command: Commands,
|
||||
}
|
||||
|
||||
static NAME: &str = env!("CARGO_PKG_NAME");
|
||||
static VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
static AUTHOR: &str = env!("CARGO_PKG_AUTHORS");
|
||||
|
||||
fn get_socket_path() -> PathBuf {
|
||||
if std::env::var("XDG_RUNTIME_DIR").is_ok() {
|
||||
std::env::var("XDG_RUNTIME_DIR")
|
||||
.map(|dir| PathBuf::from(dir).join("swim.sock"))
|
||||
.unwrap()
|
||||
} else if std::env::var("SWIM_SYSTEM_DAEMON").is_ok() {
|
||||
PathBuf::from("/run/swim.sock")
|
||||
} else {
|
||||
// Fallback to /tmp
|
||||
PathBuf::from("/tmp/swim.sock")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
if cli.license {
|
||||
println!("{}", LICENSE);
|
||||
std::process::exit(0);
|
||||
}
|
||||
match cli.command {
|
||||
Commands::Connection(_connection) => println!("Hello swim!"),
|
||||
Commands::License => println!("author(s): {}\n\n{}", AUTHOR, LICENSE),
|
||||
Commands::Version => println!("{} {}", NAME, VERSION),
|
||||
Commands::Server(cmd) => {
|
||||
let sock_path = cmd
|
||||
.socket_path
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| Some(get_socket_path()))
|
||||
.unwrap();
|
||||
|
||||
if cli.daemonize {
|
||||
let sock = UnixListener::bind(IPC_SOCKET_PATH).expect("Couldn't bind to unix socket!");
|
||||
let (tx, rx) = mpsc::channel(10);
|
||||
let sock = UnixListener::bind(sock_path).expect("Couldn't bind to unix socket!");
|
||||
let (tx, rx) = mpsc::channel(10);
|
||||
|
||||
let config = Config::default();
|
||||
let config = Config::default();
|
||||
|
||||
let app = Arc::new(AppState {
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
notify: tx,
|
||||
active_config: Arc::new(RwLock::new(Config::default())),
|
||||
});
|
||||
let app = Arc::new(AppState {
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
notify: tx,
|
||||
active_config: Arc::new(RwLock::new(Config::default())),
|
||||
});
|
||||
|
||||
// TODO: Early config init here
|
||||
// TODO: Early config init here
|
||||
|
||||
let ipc_handle = tokio::spawn(ipc_server(app.clone(), sock));
|
||||
let quic_handle = tokio::spawn(quic_server(app.clone(), rx));
|
||||
let ipc_handle = tokio::spawn(ipc_server(app.clone(), sock));
|
||||
let quic_handle = tokio::spawn(quic_server(app.clone(), rx));
|
||||
|
||||
ipc_handle.await.unwrap();
|
||||
quic_handle.await.unwrap();
|
||||
}
|
||||
|
||||
if let Some(cmd) = cli.command {
|
||||
match cmd {
|
||||
Commands::Connection(_connection) => println!("Hello swim!"),
|
||||
_ => todo!("Not yet implemented!"),
|
||||
ipc_handle.await.unwrap();
|
||||
quic_handle.await.unwrap();
|
||||
}
|
||||
_ => todo!("Not yet implemented!"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user