diff --git a/common/wireguard/src/lib.rs b/common/wireguard/src/lib.rs index 813ba4718f..cd13ba6937 100644 --- a/common/wireguard/src/lib.rs +++ b/common/wireguard/src/lib.rs @@ -7,6 +7,7 @@ mod active_peers; mod error; mod event; mod network_table; +mod packet_relayer; mod platform; mod registered_peers; mod setup; @@ -38,12 +39,26 @@ pub async fn start_wireguard( let peers_by_tag = Arc::new(std::sync::Mutex::new(wg_tunnel::PeersByTag::new())); // Start the tun device that is used to relay traffic outbound - let (tun, tun_task_tx) = tun_device::TunDevice::new(peers_by_ip.clone(), peers_by_tag.clone()); + let (tun, tun_task_tx, tun_task_response_rx) = tun_device::TunDevice::new(peers_by_ip.clone()); tun.start(); + // If we want to have the tun device on a separate host, it's the tun_task and + // tun_task_response channels that need to be sent over the network to the host where the tun + // device is running. + + // The packet relayer's responsibility is to route packets between the correct tunnel and the + // tun device. The tun device may or may not be on a separate host, which is why we can't do + // this routing in the tun device itself. 
+ let (packet_relayer, packet_tx) = packet_relayer::PacketRelayer::new( + tun_task_tx.clone(), + tun_task_response_rx, + peers_by_tag.clone(), + ); + packet_relayer.start(); + // Start the UDP listener that clients connect to let udp_listener = udp_listener::WgUdpListener::new( - tun_task_tx, + packet_tx, peers_by_ip, peers_by_tag, Arc::clone(&gateway_client_registry), diff --git a/common/wireguard/src/packet_relayer.rs b/common/wireguard/src/packet_relayer.rs new file mode 100644 index 0000000000..8f0a4ce1ac --- /dev/null +++ b/common/wireguard/src/packet_relayer.rs @@ -0,0 +1,66 @@ +use std::{collections::HashMap, sync::Arc}; + +use tap::TapFallible; +use tokio::sync::mpsc::{self}; + +use crate::{ + event::Event, + tun_task_channel::{TunTaskResponseRx, TunTaskTx}, +}; + +// The tunnels send packets to the packet relayer, which then relays them to the tun device. And +// conversely, it's where the tun device sends its responses, which are relayed back to the correct +// tunnel. +pub(crate) struct PacketRelayer { + // Receive packets from the various tunnels + packet_rx: mpsc::Receiver<(u64, Vec)>, + + // After receiving from the tunnels, send to the tun device + tun_task_tx: TunTaskTx, + + // Receive responses from the tun device + tun_task_response_rx: TunTaskResponseRx, + + // After receiving from the tun device, relay back to the correct tunnel + peers_by_tag: Arc>>>, +} + +impl PacketRelayer { + pub(crate) fn new( + tun_task_tx: TunTaskTx, + tun_task_response_rx: TunTaskResponseRx, + peers_by_tag: Arc>>>, + ) -> (Self, mpsc::Sender<(u64, Vec)>) { + let (packet_tx, packet_rx) = mpsc::channel(16); + ( + Self { + packet_rx, + tun_task_tx, + tun_task_response_rx, + peers_by_tag, + }, + packet_tx, + ) + } + + pub(crate) async fn run(mut self) { + loop { + tokio::select! 
{ + Some((tag, packet)) = self.packet_rx.recv() => { + log::info!("Sent packet to tun device with tag: {tag}"); + self.tun_task_tx.send((tag, packet)).unwrap(); + }, + Some((tag, packet)) = self.tun_task_response_rx.recv() => { + log::info!("Received response from tun device with tag: {tag}"); + self.peers_by_tag.lock().unwrap().get(&tag).and_then(|tx| { + tx.send(Event::Ip(packet.into())).tap_err(|e| log::error!("{e}")).ok() + }); + } + } + } + } + + pub(crate) fn start(self) { + tokio::spawn(async move { self.run().await }); + } +} diff --git a/common/wireguard/src/platform/linux/tun_device.rs b/common/wireguard/src/platform/linux/tun_device.rs index c714eb6cdf..cc449112bc 100644 --- a/common/wireguard/src/platform/linux/tun_device.rs +++ b/common/wireguard/src/platform/linux/tun_device.rs @@ -11,9 +11,11 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::{ event::Event, setup::{TUN_BASE_NAME, TUN_DEVICE_ADDRESS, TUN_DEVICE_NETMASK}, - tun_task_channel::{tun_task_channel, TunTaskPayload, TunTaskRx, TunTaskTx}, + tun_task_channel::{ + tun_task_channel, tun_task_response_channel, TunTaskPayload, TunTaskResponseRx, + TunTaskResponseTx, TunTaskRx, TunTaskTx, + }, udp_listener::PeersByIp, - wg_tunnel::PeersByTag, }; fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> tokio_tun::Tun { @@ -37,19 +39,21 @@ pub struct TunDevice { // Incoming data that we should send tun_task_rx: TunTaskRx, - // The routing table. - // An alternative would be to do NAT by just matching incoming with outgoing. + // And when we get replies, this is where we should send it + tun_task_response_tx: TunTaskResponseTx, + + // The routing table, as how wireguard does it peers_by_ip: Arc>, + // This is an alternative to the routing table, where we just match outgoing source IP with + // incoming destination IP. 
nat_table: HashMap, - peers_by_tag: Arc>, } impl TunDevice { pub fn new( peers_by_ip: Arc>, - peers_by_tag: Arc>, - ) -> (Self, TunTaskTx) { + ) -> (Self, TunTaskTx, TunTaskResponseRx) { let tun = setup_tokio_tun_device( format!("{TUN_BASE_NAME}%d").as_str(), TUN_DEVICE_ADDRESS.parse().unwrap(), @@ -59,19 +63,49 @@ impl TunDevice { // Channels to communicate with the other tasks let (tun_task_tx, tun_task_rx) = tun_task_channel(); + let (tun_task_response_tx, tun_task_response_rx) = tun_task_response_channel(); let tun_device = TunDevice { tun_task_rx, + tun_task_response_tx, tun, peers_by_ip, nat_table: HashMap::new(), - peers_by_tag, }; - (tun_device, tun_task_tx) + (tun_device, tun_task_tx, tun_task_response_rx) } - fn handle_tun_read(&self, packet: &[u8]) { + // Send outbound packets out on the wild internet + async fn handle_tun_write(&mut self, data: TunTaskPayload) { + let (tag, packet) = data; + let Some(dst_addr) = boringtun::noise::Tunn::dst_address(&packet) else { + log::error!("Unable to parse dst_address in packet that was supposed to be written to tun device"); + return; + }; + let Some(src_addr) = parse_src_address(&packet) else { + log::error!("Unable to parse src_address in packet that was supposed to be written to tun device"); + return; + }; + log::info!( + "iface: write Packet({src_addr} -> {dst_addr}, {} bytes)", + packet.len() + ); + + // TODO: expire old entries + self.nat_table.insert(src_addr, tag); + + self.tun + .write_all(&packet) + .await + .tap_err(|err| { + log::error!("iface: write error: {err}"); + }) + .ok(); + } + + // Receive response packets from the wild internet + async fn handle_tun_read(&self, packet: &[u8]) { let Some(dst_addr) = boringtun::noise::Tunn::dst_address(packet) else { log::error!("Unable to parse dst_address in packet that was read from tun device"); return; @@ -86,6 +120,8 @@ impl TunDevice { ); // Route packet to the correct peer. + + // This is how wireguard does it, by consulting the AllowedIPs table. 
if false { let Ok(peers) = self.peers_by_ip.lock() else { log::error!("Failed to lock peers_by_ip, aborting tun device read"); @@ -101,14 +137,15 @@ impl TunDevice { } } + // But we do it by consulting the NAT table. { if let Some(tag) = self.nat_table.get(&dst_addr) { log::info!("Forward packet to wg tunnel with tag: {tag}"); - self.peers_by_tag.lock().unwrap().get(tag).and_then(|tx| { - tx.send(Event::Ip(packet.to_vec().into())) - .tap_err(|err| log::error!("{err}")) - .ok() - }); + self.tun_task_response_tx + .send((*tag, packet.to_vec())) + .await + .tap_err(|err| log::error!("{err}")) + .ok(); return; } } @@ -116,33 +153,6 @@ impl TunDevice { log::info!("No peer found, packet dropped"); } - async fn handle_tun_write(&mut self, data: TunTaskPayload) { - let (tag, packet) = data; - let Some(dst_addr) = boringtun::noise::Tunn::dst_address(&packet) else { - log::error!("Unable to parse dst_address in packet that was supposed to be written to tun device"); - return; - }; - let Some(src_addr) = parse_src_address(&packet) else { - log::error!("Unable to parse src_address in packet that was supposed to be written to tun device"); - return; - }; - log::info!( - "iface: write Packet({src_addr} -> {dst_addr}, {} bytes)", - packet.len() - ); - - // TODO: expire old entries - self.nat_table.insert(src_addr, tag); - - self.tun - .write_all(&packet) - .await - .tap_err(|err| { - log::error!("iface: write error: {err}"); - }) - .ok(); - } - pub async fn run(mut self) { let mut buf = [0u8; 1024]; @@ -152,7 +162,7 @@ impl TunDevice { len = self.tun.read(&mut buf) => match len { Ok(len) => { let packet = &buf[..len]; - self.handle_tun_read(packet); + self.handle_tun_read(packet).await; }, Err(err) => { log::info!("iface: read error: {err}"); diff --git a/common/wireguard/src/tun_task_channel.rs b/common/wireguard/src/tun_task_channel.rs index 423243d27e..3ecc6a76b4 100644 --- a/common/wireguard/src/tun_task_channel.rs +++ b/common/wireguard/src/tun_task_channel.rs @@ -1,9 +1,10 
@@ +use tokio::sync::mpsc; + pub(crate) type TunTaskPayload = (u64, Vec); #[derive(Clone)] -pub struct TunTaskTx(tokio::sync::mpsc::UnboundedSender); - -pub(crate) struct TunTaskRx(tokio::sync::mpsc::UnboundedReceiver); +pub struct TunTaskTx(mpsc::UnboundedSender); +pub(crate) struct TunTaskRx(mpsc::UnboundedReceiver); impl TunTaskTx { pub(crate) fn send( @@ -24,3 +25,30 @@ pub(crate) fn tun_task_channel() -> (TunTaskTx, TunTaskRx) { let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::unbounded_channel(); (TunTaskTx(tun_task_tx), TunTaskRx(tun_task_rx)) } + +// Send responses from the tun device back to the PacketRelayer +pub(crate) struct TunTaskResponseTx(mpsc::Sender); +pub struct TunTaskResponseRx(mpsc::Receiver); + +impl TunTaskResponseTx { + pub(crate) async fn send( + &self, + data: TunTaskPayload, + ) -> Result<(), tokio::sync::mpsc::error::SendError> { + self.0.send(data).await + } +} + +impl TunTaskResponseRx { + pub(crate) async fn recv(&mut self) -> Option { + self.0.recv().await + } +} + +pub(crate) fn tun_task_response_channel() -> (TunTaskResponseTx, TunTaskResponseRx) { + let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::channel(16); + ( + TunTaskResponseTx(tun_task_tx), + TunTaskResponseRx(tun_task_rx), + ) +} diff --git a/common/wireguard/src/udp_listener.rs b/common/wireguard/src/udp_listener.rs index e06e0feb20..c703468d97 100644 --- a/common/wireguard/src/udp_listener.rs +++ b/common/wireguard/src/udp_listener.rs @@ -24,7 +24,6 @@ use crate::{ network_table::NetworkTable, registered_peers::{RegisteredPeer, RegisteredPeers}, setup::{self, WG_ADDRESS, WG_PORT}, - tun_task_channel::TunTaskTx, wg_tunnel::PeersByTag, }; @@ -65,7 +64,8 @@ pub struct WgUdpListener { udp: Arc, // Send data to the TUN device for sending - tun_task_tx: TunTaskTx, + // tun_task_tx: TunTaskTx, + packet_tx: mpsc::Sender<(u64, Vec)>, // Wireguard rate limiter rate_limiter: RateLimiter, @@ -75,7 +75,7 @@ pub struct WgUdpListener { impl WgUdpListener { pub async fn new( 
- tun_task_tx: TunTaskTx, + packet_tx: mpsc::Sender<(u64, Vec)>, peers_by_ip: Arc>, peers_by_tag: Arc>, gateway_client_registry: Arc, @@ -101,7 +101,7 @@ impl WgUdpListener { peers_by_ip, peers_by_tag, udp, - tun_task_tx, + packet_tx, rate_limiter, gateway_client_registry, }) @@ -207,7 +207,8 @@ impl WgUdpListener { *registered_peer.public_key, registered_peer.index, registered_peer.allowed_ips, - self.tun_task_tx.clone(), + // self.tun_task_tx.clone(), + self.packet_tx.clone(), ); self.peers_by_ip.lock().unwrap().insert(registered_peer.allowed_ips, peer_tx.clone()); diff --git a/common/wireguard/src/wg_tunnel.rs b/common/wireguard/src/wg_tunnel.rs index f9308fbe2c..32ece1e842 100644 --- a/common/wireguard/src/wg_tunnel.rs +++ b/common/wireguard/src/wg_tunnel.rs @@ -7,6 +7,7 @@ use boringtun::{ }; use bytes::Bytes; use log::{debug, error, info, warn}; +use rand::RngCore; use tap::TapFallible; use tokio::{ net::UdpSocket, @@ -14,10 +15,7 @@ use tokio::{ time::timeout, }; -use crate::{ - error::WgError, event::Event, network_table::NetworkTable, registered_peers::PeerIdx, - tun_task_channel::TunTaskTx, -}; +use crate::{error::WgError, event::Event, network_table::NetworkTable, registered_peers::PeerIdx}; const HANDSHAKE_MAX_RATE: u64 = 10; @@ -47,7 +45,8 @@ pub struct WireGuardTunnel { close_rx: broadcast::Receiver<()>, // Send data to the task that handles sending data through the tun device - tun_task_tx: TunTaskTx, + // tun_task_tx: TunTaskTx, + packet_tx: mpsc::Sender<(u64, Vec)>, tag: u64, } @@ -68,7 +67,7 @@ impl WireGuardTunnel { index: PeerIdx, peer_allowed_ips: ip_network::IpNetwork, // rate_limiter: Option, - tunnel_tx: TunTaskTx, + packet_tx: mpsc::Sender<(u64, Vec)>, ) -> (Self, mpsc::UnboundedSender, u64) { let local_addr = udp.local_addr().unwrap(); let peer_addr = udp.peer_addr(); @@ -92,7 +91,7 @@ impl WireGuardTunnel { index, rate_limiter, ) - .unwrap(), + .expect("failed to create Tunn instance"), )); // Channels with incoming data that is received 
by the main event loop @@ -104,10 +103,7 @@ impl WireGuardTunnel { let mut allowed_ips = NetworkTable::new(); allowed_ips.insert(peer_allowed_ips, ()); - // random u64 - use rand::RngCore; - let mut rng = rand::rngs::OsRng; - let tag = rng.next_u64(); + let tag = Self::new_tag(); let tunnel = WireGuardTunnel { peer_rx, @@ -117,13 +113,18 @@ impl WireGuardTunnel { wg_tunnel, close_tx, close_rx, - tun_task_tx: tunnel_tx, + packet_tx, tag, }; (tunnel, peer_tx, tag) } + fn new_tag() -> u64 { + // TODO: check for collisions + rand::thread_rng().next_u64() + } + fn close(&self) { let _ = self.close_tx.send(()); } @@ -210,14 +211,20 @@ impl WireGuardTunnel { } TunnResult::WriteToTunnelV4(packet, addr) => { if self.allowed_ips.longest_match(addr).is_some() { - self.tun_task_tx.send((self.tag, packet.to_vec())).unwrap(); + self.packet_tx + .send((self.tag, packet.to_vec())) + .await + .unwrap(); } else { warn!("Packet from {addr} not in allowed_ips"); } } TunnResult::WriteToTunnelV6(packet, addr) => { if self.allowed_ips.longest_match(addr).is_some() { - self.tun_task_tx.send((self.tag, packet.to_vec())).unwrap(); + self.packet_tx + .send((self.tag, packet.to_vec())) + .await + .unwrap(); } else { warn!("Packet (v6) from {addr} not in allowed_ips"); } @@ -319,7 +326,7 @@ pub(crate) fn start_wg_tunnel( peer_static_public: x25519::PublicKey, peer_index: PeerIdx, peer_allowed_ips: ip_network::IpNetwork, - tunnel_tx: TunTaskTx, + packet_tx: mpsc::Sender<(u64, Vec)>, ) -> ( tokio::task::JoinHandle, mpsc::UnboundedSender, @@ -332,7 +339,7 @@ pub(crate) fn start_wg_tunnel( peer_static_public, peer_index, peer_allowed_ips, - tunnel_tx, + packet_tx, ); let join_handle = tokio::spawn(async move { tunnel.spin_off().await;