Skip to content

Commit

Permalink
wireguard: add packet relayer (#4032)
Browse files Browse the repository at this point in the history
* wip

* wip: first step in putting in place forward channels

* Setup event loop for packet relayer

* tuntaskresponse

* wip

* tun task response channel

* Update comment

* done

* formatting

* nits

* Add comment
  • Loading branch information
octol committed Oct 24, 2023
1 parent 85d172e commit d80333c
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 69 deletions.
19 changes: 17 additions & 2 deletions common/wireguard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod active_peers;
mod error;
mod event;
mod network_table;
mod packet_relayer;
mod platform;
mod registered_peers;
mod setup;
Expand Down Expand Up @@ -38,12 +39,26 @@ pub async fn start_wireguard(
let peers_by_tag = Arc::new(std::sync::Mutex::new(wg_tunnel::PeersByTag::new()));

// Start the tun device that is used to relay traffic outbound
let (tun, tun_task_tx) = tun_device::TunDevice::new(peers_by_ip.clone(), peers_by_tag.clone());
let (tun, tun_task_tx, tun_task_response_rx) = tun_device::TunDevice::new(peers_by_ip.clone());
tun.start();

// If we want to have the tun device on a separate host, it's the tun_task and
// tun_task_response channels that needs to be sent over the network to the host where the tun
// device is running.

// The packet relayer's responsibility is to route packets between the correct tunnel and the
// tun device. The tun device may or may not be on a separate host, which is why we can't do
// this routing in the tun device itself.
let (packet_relayer, packet_tx) = packet_relayer::PacketRelayer::new(
tun_task_tx.clone(),
tun_task_response_rx,
peers_by_tag.clone(),
);
packet_relayer.start();

// Start the UDP listener that clients connect to
let udp_listener = udp_listener::WgUdpListener::new(
tun_task_tx,
packet_tx,
peers_by_ip,
peers_by_tag,
Arc::clone(&gateway_client_registry),
Expand Down
66 changes: 66 additions & 0 deletions common/wireguard/src/packet_relayer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::{collections::HashMap, sync::Arc};

use tap::TapFallible;
use tokio::sync::mpsc::{self};

use crate::{
event::Event,
tun_task_channel::{TunTaskResponseRx, TunTaskTx},
};

// The tunnels send packets to the packet relayer, which then relays it to the tun device. And
// conversely, it's where the tun device send responses to, which are relayed back to the correct
// tunnel.
pub(crate) struct PacketRelayer {
// Receive packets from the various tunnels
packet_rx: mpsc::Receiver<(u64, Vec<u8>)>,

// After receive from tunnels, send to the tun device
tun_task_tx: TunTaskTx,

// Receive responses from the tun device
tun_task_response_rx: TunTaskResponseRx,

// After receiving from the tun device, relay back to the correct tunnel
peers_by_tag: Arc<std::sync::Mutex<HashMap<u64, mpsc::UnboundedSender<Event>>>>,
}

impl PacketRelayer {
pub(crate) fn new(
tun_task_tx: TunTaskTx,
tun_task_response_rx: TunTaskResponseRx,
peers_by_tag: Arc<std::sync::Mutex<HashMap<u64, mpsc::UnboundedSender<Event>>>>,
) -> (Self, mpsc::Sender<(u64, Vec<u8>)>) {
let (packet_tx, packet_rx) = mpsc::channel(16);
(
Self {
packet_rx,
tun_task_tx,
tun_task_response_rx,
peers_by_tag,
},
packet_tx,
)
}

pub(crate) async fn run(mut self) {
loop {
tokio::select! {
Some((tag, packet)) = self.packet_rx.recv() => {
log::info!("Sent packet to tun device with tag: {tag}");
self.tun_task_tx.send((tag, packet)).unwrap();
},
Some((tag, packet)) = self.tun_task_response_rx.recv() => {
log::info!("Received response from tun device with tag: {tag}");
self.peers_by_tag.lock().unwrap().get(&tag).and_then(|tx| {
tx.send(Event::Ip(packet.into())).tap_err(|e| log::error!("{e}")).ok()
});
}
}
}
}

pub(crate) fn start(self) {
tokio::spawn(async move { self.run().await });
}
}
96 changes: 53 additions & 43 deletions common/wireguard/src/platform/linux/tun_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::{
event::Event,
setup::{TUN_BASE_NAME, TUN_DEVICE_ADDRESS, TUN_DEVICE_NETMASK},
tun_task_channel::{tun_task_channel, TunTaskPayload, TunTaskRx, TunTaskTx},
tun_task_channel::{
tun_task_channel, tun_task_response_channel, TunTaskPayload, TunTaskResponseRx,
TunTaskResponseTx, TunTaskRx, TunTaskTx,
},
udp_listener::PeersByIp,
wg_tunnel::PeersByTag,
};

fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> tokio_tun::Tun {
Expand All @@ -37,19 +39,21 @@ pub struct TunDevice {
// Incoming data that we should send
tun_task_rx: TunTaskRx,

// The routing table.
// An alternative would be to do NAT by just matching incoming with outgoing.
// And when we get replies, this is where we should send it
tun_task_response_tx: TunTaskResponseTx,

// The routing table, as how wireguard does it
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,

// This is an alternative to the routing table, where we just match outgoing source IP with
// incoming destination IP.
nat_table: HashMap<IpAddr, u64>,
peers_by_tag: Arc<std::sync::Mutex<PeersByTag>>,
}

impl TunDevice {
pub fn new(
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,
peers_by_tag: Arc<std::sync::Mutex<PeersByTag>>,
) -> (Self, TunTaskTx) {
) -> (Self, TunTaskTx, TunTaskResponseRx) {
let tun = setup_tokio_tun_device(
format!("{TUN_BASE_NAME}%d").as_str(),
TUN_DEVICE_ADDRESS.parse().unwrap(),
Expand All @@ -59,19 +63,49 @@ impl TunDevice {

// Channels to communicate with the other tasks
let (tun_task_tx, tun_task_rx) = tun_task_channel();
let (tun_task_response_tx, tun_task_response_rx) = tun_task_response_channel();

let tun_device = TunDevice {
tun_task_rx,
tun_task_response_tx,
tun,
peers_by_ip,
nat_table: HashMap::new(),
peers_by_tag,
};

(tun_device, tun_task_tx)
(tun_device, tun_task_tx, tun_task_response_rx)
}

fn handle_tun_read(&self, packet: &[u8]) {
// Send outbound packets out on the wild internet
async fn handle_tun_write(&mut self, data: TunTaskPayload) {
let (tag, packet) = data;
let Some(dst_addr) = boringtun::noise::Tunn::dst_address(&packet) else {
log::error!("Unable to parse dst_address in packet that was supposed to be written to tun device");
return;
};
let Some(src_addr) = parse_src_address(&packet) else {
log::error!("Unable to parse src_address in packet that was supposed to be written to tun device");
return;
};
log::info!(
"iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
packet.len()
);

// TODO: expire old entries
self.nat_table.insert(src_addr, tag);

self.tun
.write_all(&packet)
.await
.tap_err(|err| {
log::error!("iface: write error: {err}");
})
.ok();
}

// Receive reponse packets from the wild internet
async fn handle_tun_read(&self, packet: &[u8]) {
let Some(dst_addr) = boringtun::noise::Tunn::dst_address(packet) else {
log::error!("Unable to parse dst_address in packet that was read from tun device");
return;
Expand All @@ -86,6 +120,8 @@ impl TunDevice {
);

// Route packet to the correct peer.

// This is how wireguard does it, by consulting the AllowedIPs table.
if false {
let Ok(peers) = self.peers_by_ip.lock() else {
log::error!("Failed to lock peers_by_ip, aborting tun device read");
Expand All @@ -101,48 +137,22 @@ impl TunDevice {
}
}

// But we do it by consulting the NAT table.
{
if let Some(tag) = self.nat_table.get(&dst_addr) {
log::info!("Forward packet to wg tunnel with tag: {tag}");
self.peers_by_tag.lock().unwrap().get(tag).and_then(|tx| {
tx.send(Event::Ip(packet.to_vec().into()))
.tap_err(|err| log::error!("{err}"))
.ok()
});
self.tun_task_response_tx
.send((*tag, packet.to_vec()))
.await
.tap_err(|err| log::error!("{err}"))
.ok();
return;
}
}

log::info!("No peer found, packet dropped");
}

async fn handle_tun_write(&mut self, data: TunTaskPayload) {
let (tag, packet) = data;
let Some(dst_addr) = boringtun::noise::Tunn::dst_address(&packet) else {
log::error!("Unable to parse dst_address in packet that was supposed to be written to tun device");
return;
};
let Some(src_addr) = parse_src_address(&packet) else {
log::error!("Unable to parse src_address in packet that was supposed to be written to tun device");
return;
};
log::info!(
"iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
packet.len()
);

// TODO: expire old entries
self.nat_table.insert(src_addr, tag);

self.tun
.write_all(&packet)
.await
.tap_err(|err| {
log::error!("iface: write error: {err}");
})
.ok();
}

pub async fn run(mut self) {
let mut buf = [0u8; 1024];

Expand All @@ -152,7 +162,7 @@ impl TunDevice {
len = self.tun.read(&mut buf) => match len {
Ok(len) => {
let packet = &buf[..len];
self.handle_tun_read(packet);
self.handle_tun_read(packet).await;
},
Err(err) => {
log::info!("iface: read error: {err}");
Expand Down
34 changes: 31 additions & 3 deletions common/wireguard/src/tun_task_channel.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use tokio::sync::mpsc;

pub(crate) type TunTaskPayload = (u64, Vec<u8>);

#[derive(Clone)]
pub struct TunTaskTx(tokio::sync::mpsc::UnboundedSender<TunTaskPayload>);

pub(crate) struct TunTaskRx(tokio::sync::mpsc::UnboundedReceiver<TunTaskPayload>);
pub struct TunTaskTx(mpsc::UnboundedSender<TunTaskPayload>);
pub(crate) struct TunTaskRx(mpsc::UnboundedReceiver<TunTaskPayload>);

impl TunTaskTx {
pub(crate) fn send(
Expand All @@ -24,3 +25,30 @@ pub(crate) fn tun_task_channel() -> (TunTaskTx, TunTaskRx) {
let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::unbounded_channel();
(TunTaskTx(tun_task_tx), TunTaskRx(tun_task_rx))
}

// Send responses back from the tun device back to the PacketRelayer
pub(crate) struct TunTaskResponseTx(mpsc::Sender<TunTaskPayload>);
pub struct TunTaskResponseRx(mpsc::Receiver<TunTaskPayload>);

impl TunTaskResponseTx {
pub(crate) async fn send(
&self,
data: TunTaskPayload,
) -> Result<(), tokio::sync::mpsc::error::SendError<TunTaskPayload>> {
self.0.send(data).await
}
}

impl TunTaskResponseRx {
pub(crate) async fn recv(&mut self) -> Option<TunTaskPayload> {
self.0.recv().await
}
}

pub(crate) fn tun_task_response_channel() -> (TunTaskResponseTx, TunTaskResponseRx) {
let (tun_task_tx, tun_task_rx) = tokio::sync::mpsc::channel(16);
(
TunTaskResponseTx(tun_task_tx),
TunTaskResponseRx(tun_task_rx),
)
}
11 changes: 6 additions & 5 deletions common/wireguard/src/udp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use crate::{
network_table::NetworkTable,
registered_peers::{RegisteredPeer, RegisteredPeers},
setup::{self, WG_ADDRESS, WG_PORT},
tun_task_channel::TunTaskTx,
wg_tunnel::PeersByTag,
};

Expand Down Expand Up @@ -65,7 +64,8 @@ pub struct WgUdpListener {
udp: Arc<UdpSocket>,

// Send data to the TUN device for sending
tun_task_tx: TunTaskTx,
// tun_task_tx: TunTaskTx,
packet_tx: mpsc::Sender<(u64, Vec<u8>)>,

// Wireguard rate limiter
rate_limiter: RateLimiter,
Expand All @@ -75,7 +75,7 @@ pub struct WgUdpListener {

impl WgUdpListener {
pub async fn new(
tun_task_tx: TunTaskTx,
packet_tx: mpsc::Sender<(u64, Vec<u8>)>,
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,
peers_by_tag: Arc<std::sync::Mutex<PeersByTag>>,
gateway_client_registry: Arc<GatewayClientRegistry>,
Expand All @@ -101,7 +101,7 @@ impl WgUdpListener {
peers_by_ip,
peers_by_tag,
udp,
tun_task_tx,
packet_tx,
rate_limiter,
gateway_client_registry,
})
Expand Down Expand Up @@ -207,7 +207,8 @@ impl WgUdpListener {
*registered_peer.public_key,
registered_peer.index,
registered_peer.allowed_ips,
self.tun_task_tx.clone(),
// self.tun_task_tx.clone(),
self.packet_tx.clone(),
);

self.peers_by_ip.lock().unwrap().insert(registered_peer.allowed_ips, peer_tx.clone());
Expand Down
Loading

0 comments on commit d80333c

Please sign in to comment.