(WIP) Data structure refactor
Signed-off-by: Solomon Wagner <solow@solow.xyz>
This commit is contained in:
parent
00076e2b42
commit
ec63debe4f
@ -1,18 +1,21 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "swim-rs"
|
name = "swim"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
authors = ["Solomon W. <solow@solow.xyz>"]
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.96"
|
anyhow = "1.0.96"
|
||||||
|
base64 = "0.22.1"
|
||||||
clap = { version = "4.5.30", features = ["derive", "env"] }
|
clap = { version = "4.5.30", features = ["derive", "env"] }
|
||||||
colored = "3.0.0"
|
colored = "3.0.0"
|
||||||
quinn = "0.11.6"
|
quinn = "0.11.6"
|
||||||
serde = { version = "1.0.218", features = ["derive"] }
|
serde = { version = "1.0.218", features = ["derive"] }
|
||||||
serde_json = "1.0.139"
|
serde_json = "1.0.139"
|
||||||
tokio = { version = "1.43.0", features = ["full"] }
|
tokio = { version = "1.43.0", features = ["full"] }
|
||||||
|
toml = "0.8.20"
|
||||||
tracing = "0.1.41"
|
tracing = "0.1.41"
|
||||||
tracing-subscriber = { version = "0.3.19", features = ["fmt", "env-filter"] }
|
tracing-subscriber = { version = "0.3.19", features = ["fmt", "env-filter"] }
|
||||||
|
|
||||||
|
283
src/main.rs
283
src/main.rs
@ -1,19 +1,22 @@
|
|||||||
use std::collections::BTreeMap;
|
use anyhow::{Context, Result};
|
||||||
use std::sync::Arc;
|
use base64::Engine;
|
||||||
|
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::BTreeSet;
|
||||||
|
use std::fmt::Display;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::str::FromStr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::{collections::BTreeMap, path::PathBuf};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Result},
|
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||||
net::{UnixListener, UnixStream},
|
net::{UnixListener, UnixStream},
|
||||||
sync::{self, mpsc, RwLock},
|
sync::{self, broadcast, mpsc, RwLock},
|
||||||
};
|
};
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use tracing::{error, info, instrument, warn};
|
use tracing::{error, info, instrument, warn};
|
||||||
|
|
||||||
static LICENSE: &'static str = include_str!("../LICENSE");
|
static LICENSE: &str = include_str!("../LICENSE");
|
||||||
|
|
||||||
const IPC_SOCKET_PATH: &str = "/tmp/swimd.sock";
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||||
enum SockType {
|
enum SockType {
|
||||||
@ -21,22 +24,69 @@ enum SockType {
|
|||||||
Tcp,
|
Tcp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Display for SockType {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
SockType::Udp => f.write_str("UDP"),
|
||||||
|
SockType::Tcp => f.write_str("TCP"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||||
struct PortMapping {
|
struct PortMapping {
|
||||||
sock_type: SockType,
|
sock_type: SockType,
|
||||||
from: u32,
|
from: u16,
|
||||||
to: u32,
|
to: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for PortMapping {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_fmt(format_args!(
|
||||||
|
"{}->{}/{}",
|
||||||
|
self.from, self.to, self.sock_type
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||||
|
struct Remote {
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
public_key: Option<[u8; 32]>,
|
||||||
|
default_tunnel: Tunnel,
|
||||||
|
tunnels: Vec<Tunnel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||||
|
struct Tunnel {
|
||||||
|
mappings: BTreeMap<u32, PortMapping>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
|
||||||
struct Config {
|
struct Config<'a> {
|
||||||
mappings: BTreeMap<u32, PortMapping>,
|
remotes: BTreeMap<String, Remote>,
|
||||||
|
#[serde(skip)] // (<remote_name> <tunnel idx> <mapping idx>)
|
||||||
|
flat_mappings: BTreeSet<(String, usize, &'a PortMapping)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Config<'a> {
|
||||||
|
fn update_mappings(&'a mut self) {
|
||||||
|
self.flat_mappings.clear();
|
||||||
|
for (name, rem) in self.remotes.iter() {
|
||||||
|
for (tidx, tun) in rem.tunnels.iter().enumerate() {
|
||||||
|
for mapping in tun.mappings.values() {
|
||||||
|
self.flat_mappings.insert((name.clone(), tidx, mapping));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct AppState {
|
struct AppState {
|
||||||
active_config: Arc<RwLock<Config>>,
|
active_config: Arc<RwLock<Config<'static>>>,
|
||||||
config: Arc<RwLock<Config>>,
|
config: Arc<RwLock<Config<'static>>>,
|
||||||
notify: sync::mpsc::Sender<()>, // Used to tell the ipc server to reload
|
notify: sync::mpsc::Sender<()>, // Used to tell the ipc server to reload
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,7 +111,7 @@ enum ClientResp<S: AsRef<str>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<S: AsRef<str> + Serialize> ClientResp<S> {
|
impl<S: AsRef<str> + Serialize> ClientResp<S> {
|
||||||
async fn write_to<W>(&self, mut writer: W) -> Result<()>
|
async fn write_to<W>(&self, mut writer: W) -> std::io::Result<()>
|
||||||
where
|
where
|
||||||
W: AsyncWriteExt + std::marker::Unpin,
|
W: AsyncWriteExt + std::marker::Unpin,
|
||||||
{
|
{
|
||||||
@ -87,23 +137,24 @@ async fn handle_ipc_client(mut conn: UnixStream, app: Arc<AppState>) {
|
|||||||
let resp: ClientResp<_> = match serde_json::from_str::<IPCRequest>(&buf) {
|
let resp: ClientResp<_> = match serde_json::from_str::<IPCRequest>(&buf) {
|
||||||
Ok(req) => match req.op {
|
Ok(req) => match req.op {
|
||||||
ClientOp::Del(ni) => {
|
ClientOp::Del(ni) => {
|
||||||
let cfg = config.write().await;
|
|
||||||
let km: Vec<String> = ni
|
|
||||||
.iter()
|
|
||||||
.filter(|i| cfg.mappings.contains_key(i))
|
|
||||||
.map(|n| n.to_string())
|
|
||||||
.collect();
|
|
||||||
let kml = km.join(",");
|
|
||||||
if !km.len() == 0 {
|
|
||||||
error!("Bad rule ids when deleting: {}", kml);
|
|
||||||
|
|
||||||
ClientResp::Error(format!("Some route IDs do not exist: {}", kml))
|
// let cfg = config.write().await;
|
||||||
} else {
|
// let km: Vec<String> = ni
|
||||||
for i in ni {
|
// .iter()
|
||||||
info!("Removed ruleset {i}!");
|
// .filter(|i| cfg.mappings.contains_key(i))
|
||||||
}
|
// .map(|n| n.to_string())
|
||||||
ClientResp::Ok(Some(format!("Removed routes: {kml}")))
|
// .collect();
|
||||||
}
|
// let kml = km.join(",");
|
||||||
|
// if !km.len() == 0 {
|
||||||
|
// error!("Bad rule ids when deleting: {}", kml);
|
||||||
|
//
|
||||||
|
// ClientResp::Error(format!("Some route IDs do not exist: {}", kml))
|
||||||
|
// } else {
|
||||||
|
// for i in ni {
|
||||||
|
// info!("Removed ruleset {i}!");
|
||||||
|
// }
|
||||||
|
// ClientResp::Ok(Some(format!("Removed routes: {kml}")))
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
ClientOp::Export => {
|
ClientOp::Export => {
|
||||||
let cfg = serde_json::to_string(&*config.read().await).unwrap();
|
let cfg = serde_json::to_string(&*config.read().await).unwrap();
|
||||||
@ -150,11 +201,23 @@ async fn ipc_server(app: Arc<AppState>, sock: UnixListener) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn spawn_listeners(cfg: &Config) {
|
||||||
|
for (id, mapping) in &cfg.mappings {
|
||||||
|
info!("Starting port mapping {id} : {mapping}");
|
||||||
|
let handle = tokio::spawn(listener_run);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn quic_server(app: Arc<AppState>, mut notify: mpsc::Receiver<()>) {
|
async fn quic_server(app: Arc<AppState>, mut notify: mpsc::Receiver<()>) {
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
let (tx, rx) = broadcast::channel::<()>(5);
|
||||||
|
|
||||||
while notify.recv().await.is_some() {
|
while notify.recv().await.is_some() {
|
||||||
let mut active = app.active_config.write().await;
|
let mut active = app.active_config.write().await;
|
||||||
let cfg = app.config.read().await.clone();
|
let cfg = app.config.read().await.clone();
|
||||||
*active = cfg;
|
*active = cfg;
|
||||||
|
let a = app.active_config.read().await;
|
||||||
|
tx.send(());
|
||||||
info!("Reloaded config!");
|
info!("Reloaded config!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -179,59 +242,155 @@ enum Connection {
|
|||||||
Info,
|
Info,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, clap::Args)]
|
||||||
|
struct ServerArgs {
|
||||||
|
#[arg(short, long, help = "Causes the server to fork from the shell")]
|
||||||
|
daemonize: bool,
|
||||||
|
#[arg(short, long, help = "Change the control socket path")]
|
||||||
|
socket_path: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
struct PubKey([u8; 32]);
|
||||||
|
|
||||||
|
impl FromStr for PubKey {
|
||||||
|
type Err = String;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||||
|
if s.len() != 43 {
|
||||||
|
return Err(format!(
|
||||||
|
"Public key must be exactly 43 characters, got {}",
|
||||||
|
s.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let decoded = base64::engine::general_purpose::STANDARD
|
||||||
|
.decode(s)
|
||||||
|
.map_err(|e| format!("Invalid base64 string: {}", e))?;
|
||||||
|
|
||||||
|
if decoded.len() != 32 {
|
||||||
|
return Err(format!(
|
||||||
|
"Decoded key must have a length of 32, got {}",
|
||||||
|
decoded.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut key = [0u8; 32];
|
||||||
|
key.copy_from_slice(&decoded);
|
||||||
|
Ok(PubKey(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for PubKey {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
serializer.serialize_str(&base64::engine::general_purpose::STANDARD.encode(self.0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for PubKey {
|
||||||
|
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
PubKey::from_str(&String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, clap::Args)]
|
||||||
|
struct RemoteAddArgs {
|
||||||
|
#[arg(help = "Host ip/domain of a server")]
|
||||||
|
host: String,
|
||||||
|
#[arg(short, long, help = "Chatter port for the server")]
|
||||||
|
port: u8,
|
||||||
|
#[arg(short, long, help = "Pubkey for the remote (skips key exchange)")]
|
||||||
|
pubkey: PubKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand, Debug)]
|
||||||
|
#[command(
|
||||||
|
about = "Operations related to remotes (peers)",
|
||||||
|
long_about = "Tunnels are created between remotes"
|
||||||
|
)]
|
||||||
|
enum RemoteCmd {
|
||||||
|
Add(RemoteAddArgs),
|
||||||
|
List,
|
||||||
|
Rm,
|
||||||
|
Info,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Subcommand, Debug)]
|
#[derive(Subcommand, Debug)]
|
||||||
enum Commands {
|
enum Commands {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
Connection(Connection),
|
Connection(Connection),
|
||||||
Remote,
|
#[command(subcommand)]
|
||||||
|
Remote(RemoteCmd),
|
||||||
|
License,
|
||||||
|
Version,
|
||||||
|
Server(ServerArgs),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author = "Solomon W.", version, about = "The tunneling firewall")]
|
#[command(about = "The tunneling firewall")]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
#[arg(long, help = "Run as the swim daemon")]
|
|
||||||
daemonize: bool,
|
|
||||||
#[arg(long, help = "Display the license")]
|
|
||||||
license: bool,
|
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
command: Option<Commands>,
|
command: Commands,
|
||||||
|
}
|
||||||
|
|
||||||
|
static NAME: &str = env!("CARGO_PKG_NAME");
|
||||||
|
static VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
|
static AUTHOR: &str = env!("CARGO_PKG_AUTHORS");
|
||||||
|
|
||||||
|
fn get_socket_path() -> PathBuf {
|
||||||
|
if std::env::var("XDG_RUNTIME_DIR").is_ok() {
|
||||||
|
std::env::var("XDG_RUNTIME_DIR")
|
||||||
|
.map(|dir| PathBuf::from(dir).join("swim.sock"))
|
||||||
|
.unwrap()
|
||||||
|
} else if std::env::var("SWIM_SYSTEM_DAEMON").is_ok() {
|
||||||
|
PathBuf::from("/run/swim.sock")
|
||||||
|
} else {
|
||||||
|
// Fallback to /tmp
|
||||||
|
PathBuf::from("/tmp/swim.sock")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
|
|
||||||
if cli.license {
|
match cli.command {
|
||||||
println!("{}", LICENSE);
|
Commands::Connection(_connection) => println!("Hello swim!"),
|
||||||
std::process::exit(0);
|
Commands::License => println!("author(s): {}\n\n{}", AUTHOR, LICENSE),
|
||||||
}
|
Commands::Version => println!("{} {}", NAME, VERSION),
|
||||||
|
Commands::Server(cmd) => {
|
||||||
|
let sock_path = cmd
|
||||||
|
.socket_path
|
||||||
|
.map(PathBuf::from)
|
||||||
|
.or_else(|| Some(get_socket_path()))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
if cli.daemonize {
|
let sock = UnixListener::bind(sock_path).expect("Couldn't bind to unix socket!");
|
||||||
let sock = UnixListener::bind(IPC_SOCKET_PATH).expect("Couldn't bind to unix socket!");
|
let (tx, rx) = mpsc::channel(10);
|
||||||
let (tx, rx) = mpsc::channel(10);
|
|
||||||
|
|
||||||
let config = Config::default();
|
let config = Config::default();
|
||||||
|
|
||||||
let app = Arc::new(AppState {
|
let app = Arc::new(AppState {
|
||||||
config: Arc::new(RwLock::new(config)),
|
config: Arc::new(RwLock::new(config)),
|
||||||
notify: tx,
|
notify: tx,
|
||||||
active_config: Arc::new(RwLock::new(Config::default())),
|
active_config: Arc::new(RwLock::new(Config::default())),
|
||||||
});
|
});
|
||||||
|
|
||||||
// TODO: Early config init here
|
// TODO: Early config init here
|
||||||
|
|
||||||
let ipc_handle = tokio::spawn(ipc_server(app.clone(), sock));
|
let ipc_handle = tokio::spawn(ipc_server(app.clone(), sock));
|
||||||
let quic_handle = tokio::spawn(quic_server(app.clone(), rx));
|
let quic_handle = tokio::spawn(quic_server(app.clone(), rx));
|
||||||
|
|
||||||
ipc_handle.await.unwrap();
|
ipc_handle.await.unwrap();
|
||||||
quic_handle.await.unwrap();
|
quic_handle.await.unwrap();
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(cmd) = cli.command {
|
|
||||||
match cmd {
|
|
||||||
Commands::Connection(_connection) => println!("Hello swim!"),
|
|
||||||
_ => todo!("Not yet implemented!"),
|
|
||||||
}
|
}
|
||||||
|
_ => todo!("Not yet implemented!"),
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user