Skip to content

Commit

Permalink
wip: factorize cmsg API using a helper wrapper struct
Browse files Browse the repository at this point in the history
  • Loading branch information
stormshield-damiend committed Nov 15, 2023
1 parent 147af5f commit c9b82fc
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 124 deletions.
1 change: 1 addition & 0 deletions quinn-udp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod cmsg;
#[path = "unix.rs"]
mod imp;

// FIXME rename and add CmsgHelper in unix.rs to factorize
#[cfg(windows)]
#[path = "wsa_cmsg.rs"]
mod cmsg;
Expand Down
137 changes: 128 additions & 9 deletions quinn-udp/src/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fn send(
for transmit in transmits {
// we cannot use [`socket2::sendmsg()`] and [`socket2::MsgHdr`] as we do not have access
// to the inner field which holds the WSAMSG
let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
let mut ctrl_buf = Aligned([0; CMSG_LEN]);
let daddr = socket2::SockAddr::from(transmit.destination);

let mut data = WinSock::WSABUF {
Expand All @@ -173,8 +173,10 @@ fn send(
dwFlags: 0,
};

let mut helper = unsafe { CmsgHelper::new(&mut wsa_msg) };

// Add control messages (ECN and PKTINFO)
let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) };
let mut encoder = unsafe { cmsg::Encoder::new(&mut helper) };

if let Some(ip) = transmit.src_ip {
let ip = std::net::SocketAddr::new(ip, 0);
Expand Down Expand Up @@ -253,7 +255,7 @@ fn recv(
.expect("Valid function pointer for WSARecvMsg");

// we cannot use [`socket2::MsgHdrMut`] as we do not have access to inner field which holds the WSAMSG
let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
let mut ctrl_buf = Aligned([0; CMSG_LEN]);
let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() };
let mut data = WinSock::WSABUF {
buf: bufs[0].as_mut_ptr(),
Expand Down Expand Up @@ -303,27 +305,28 @@ fn recv(
let mut ecn_bits = 0;
let mut dst_ip = None;

let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) };
for cmsg in cmsg_iter {
let helper = unsafe { CmsgHelper::new(&mut wsa_msg) };
for cmsg in helper {
// [header (len)][data][padding(len + sizeof(data))] -> [header][data][padding]
match (cmsg.cmsg_level, cmsg.cmsg_type) {
(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => {
let pktinfo = unsafe { cmsg::decode::<WinSock::IN_PKTINFO>(cmsg) };
let pktinfo = unsafe { CmsgHelper::decode::<WinSock::IN_PKTINFO>(cmsg) };
// Addr is stored in big endian format
let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr }));
dst_ip = Some(ip4.into());
}
(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => {
let pktinfo = unsafe { cmsg::decode::<WinSock::IN6_PKTINFO>(cmsg) };
let pktinfo = unsafe { CmsgHelper::decode::<WinSock::IN6_PKTINFO>(cmsg) };
// Addr is stored in big endian format
dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte }));
}
(WinSock::IPPROTO_IP, WinSock::IP_ECN) => {
// ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
ecn_bits = unsafe { cmsg::decode::<i32>(cmsg) };
ecn_bits = unsafe { CmsgHelper::decode::<i32>(cmsg) };
}
(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => {
// ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
ecn_bits = unsafe { cmsg::decode::<i32>(cmsg) };
ecn_bits = unsafe { CmsgHelper::decode::<i32>(cmsg) };
}
_ => {}
}
Expand Down Expand Up @@ -396,3 +399,119 @@ fn set_socket_option(
}

const OPTION_ON: u32 = 1;

#[derive(Copy, Clone)]
#[repr(align(8))] // Conservative bound for align_of<WinSock::CMSGHDR>
struct Aligned<T>(pub(crate) T);

/// Cmsg Helper wrapping [`WinSock::WSAMSG`] and [`WinSock::CMSGHDR`]
pub(crate) struct CmsgHelper<'a> {
hdr: &'a mut WinSock::WSAMSG,
cmsg: Option<&'a mut WinSock::CMSGHDR>,
}

impl<'a> CmsgHelper<'a> {
/// # Safety
/// - `hdr.Control.buf` must be a suitably aligned pointer to `hdr.Control.len` bytes that
/// can be safely written
pub(crate) unsafe fn new(hdr: &'a mut WinSock::WSAMSG) -> Self {
Self {
cmsg: Self::cmsg_firsthdr(hdr).as_mut(),
hdr,
}
}

pub(crate) fn control_len(&self) -> usize {
self.hdr.Control.len as usize
}

pub(crate) fn set_control_len(&mut self, len: usize) {
self.hdr.Control.len = len as _;
}

pub(crate) fn cmsg_take(&mut self) -> Option<&'a mut WinSock::CMSGHDR> {
self.cmsg.take()
}

pub(crate) fn set_cmsg(&mut self, cmsg: Option<&'a mut WinSock::CMSGHDR>) {
self.cmsg = cmsg;
}

pub(crate) fn cmsghdr_align_of() -> usize {
mem::align_of::<WinSock::CMSGHDR>()
}

/// # Safety
///
/// `cmsg` must refer to a [`WinSock::CMSGHDR`] containing a payload of type `T`
unsafe fn decode<T: Copy>(cmsg: &WinSock::CMSGHDR) -> T {
assert!(mem::align_of::<T>() <= mem::align_of::<WinSock::CMSGHDR>());
debug_assert_eq!(
cmsg.cmsg_len,
CmsgHelper::cmsg_len(mem::size_of::<T>() as _)
);
ptr::read(CmsgHelper::cmsg_data(cmsg) as *const T)
}

// Helpers functions based on C macros from
// https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741
pub(crate) fn cmsghdr_align(length: usize) -> usize {
(length + mem::align_of::<WinSock::CMSGHDR>() - 1)
& !(mem::align_of::<WinSock::CMSGHDR>() - 1)
}

pub(crate) fn cmsgdata_align(length: usize) -> usize {
(length + mem::align_of::<usize>() - 1) & !(mem::align_of::<usize>() - 1)
}

pub(crate) unsafe fn cmsg_firsthdr(msg: *const WinSock::WSAMSG) -> *mut WinSock::CMSGHDR {
if (*msg).Control.len as usize >= mem::size_of::<WinSock::CMSGHDR>() {
(*msg).Control.buf as *mut WinSock::CMSGHDR
} else {
ptr::null_mut::<WinSock::CMSGHDR>()
}
}

pub(crate) unsafe fn cmsg_nxthdr(
&self,
cmsg: *const WinSock::CMSGHDR,
) -> *mut WinSock::CMSGHDR {
if cmsg.is_null() {
return Self::cmsg_firsthdr(self.hdr);
}
let next = (cmsg as usize + Self::cmsghdr_align((*cmsg).cmsg_len)) as *mut WinSock::CMSGHDR;
let max = self.hdr.Control.buf as usize + self.hdr.Control.len as usize;
if (next.offset(1)) as usize > max {
ptr::null_mut()
} else {
next
}
}

pub(crate) unsafe fn cmsg_data(cmsg: *const WinSock::CMSGHDR) -> *mut u8 {
(cmsg as usize + Self::cmsgdata_align(mem::size_of::<WinSock::CMSGHDR>())) as *mut u8
}

pub(crate) fn cmsg_space(length: usize) -> usize {
Self::cmsgdata_align(mem::size_of::<WinSock::CMSGHDR>() + Self::cmsghdr_align(length))
}

pub(crate) fn cmsg_len(length: usize) -> usize {
Self::cmsgdata_align(mem::size_of::<WinSock::CMSGHDR>()) + length
}
}

impl<'a> Iterator for CmsgHelper<'a> {
type Item = &'a WinSock::CMSGHDR;

/// # Safety
///
/// `self.hdr.Control.buf` must point to memory outliving `'a` which can be soundly read for the
/// lifetime of the constructed `Iter` and contains a buffer of [`WinSock::CMSGHDR`], i.e.
/// is aligned for [`WinSock::CMSGHDR`], is fully initialized, and has correct internal links.
fn next(&mut self) -> Option<&'a WinSock::CMSGHDR> {
let current = self.cmsg.take()?;
self.cmsg = unsafe { CmsgHelper::cmsg_nxthdr(self, current).as_mut() };
Some(current)
}
}
144 changes: 29 additions & 115 deletions quinn-udp/src/wsa_cmsg.rs
Original file line number Diff line number Diff line change
@@ -1,103 +1,54 @@
use std::{mem, ptr};
use std::{ffi::c_int, mem, ptr};

use windows_sys::Win32::Networking::WinSock;
use crate::imp::CmsgHelper;

#[derive(Copy, Clone)]
#[repr(align(8))] // Conservative bound for align_of<WinSock::CMSGHDR>
pub(crate) struct Aligned<T>(pub(crate) T);

// Helpers functions based on C macros from
// https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741
fn wsa_cmsghdr_align(length: usize) -> usize {
(length + mem::align_of::<WinSock::CMSGHDR>() - 1) & !(mem::align_of::<WinSock::CMSGHDR>() - 1)
}

fn wsa_cmsgdata_align(length: usize) -> usize {
(length + mem::align_of::<usize>() - 1) & !(mem::align_of::<usize>() - 1)
}

unsafe fn wsa_cmsg_firsthdr(msg: *const WinSock::WSAMSG) -> *mut WinSock::CMSGHDR {
if (*msg).Control.len as usize >= mem::size_of::<WinSock::CMSGHDR>() {
(*msg).Control.buf as *mut WinSock::CMSGHDR
} else {
ptr::null_mut::<WinSock::CMSGHDR>()
}
}

unsafe fn wsa_cmsg_nxthdr(
msg: *const WinSock::WSAMSG,
cmsg: *const WinSock::CMSGHDR,
) -> *mut WinSock::CMSGHDR {
if cmsg.is_null() {
return wsa_cmsg_firsthdr(msg);
}
let next = (cmsg as usize + wsa_cmsghdr_align((*cmsg).cmsg_len)) as *mut WinSock::CMSGHDR;
let max = (*msg).Control.buf as usize + (*msg).Control.len as usize;
if (next.offset(1)) as usize > max {
ptr::null_mut()
} else {
next
}
}

unsafe fn wsa_cmsg_data(cmsg: *const WinSock::CMSGHDR) -> *mut u8 {
(cmsg as usize + wsa_cmsgdata_align(mem::size_of::<WinSock::CMSGHDR>())) as *mut u8
}

fn wsa_cmsg_space(length: usize) -> usize {
wsa_cmsgdata_align(mem::size_of::<WinSock::CMSGHDR>() + wsa_cmsghdr_align(length))
}

fn wsa_cmsg_len(length: usize) -> usize {
wsa_cmsgdata_align(mem::size_of::<WinSock::CMSGHDR>()) + length
}

/// Helper to encode a series of control messages ("cmsgs") to a buffer for use in `WSASendMsg`.
/// Helper to encode a series of control messages ("cmsgs") to a buffer for use in a `sendmsg`` like function
///
/// The operation must be "finished" for the `WSAMSG`` to be usable, either by calling `finish`
/// The operation must be "finished" for the message to be usable, either by calling `finish`
/// explicitly or by dropping the `Encoder`.
pub(crate) struct Encoder<'a> {
hdr: &'a mut WinSock::WSAMSG,
cmsg: Option<&'a mut WinSock::CMSGHDR>,
helper: &'a mut CmsgHelper<'a>,
len: usize,
}

impl<'a> Encoder<'a> {
/// # Safety
/// - `hdr.Control.buf` must be a suitably aligned pointer to `hdr.Control.len` bytes that
/// can be safely written
/// - The `Encoder` must be dropped before `hdr` is passed to a system call, and must not be leaked.
pub(crate) unsafe fn new(hdr: &'a mut WinSock::WSAMSG) -> Self {
Self {
cmsg: wsa_cmsg_firsthdr(hdr).as_mut(),
hdr,
len: 0,
}
/// - The `CmsgHelper` handles all the alignement constraints
/// - The `Encoder` must be dropped before the native build message is passed to a system call,
/// and must not be leaked.
pub(crate) unsafe fn new(helper: &'a mut CmsgHelper<'a>) -> Self {
Self { helper, len: 0 }
}

/// Append a control message ([`WinSock::CMSGHDR`]) to the buffer.
/// Append a native control message to the buffer.
///
/// # Panics
/// - If insufficient buffer space remains.
/// - If `T` has stricter alignment requirements than `cmsghdr`
pub(crate) fn push<T: Copy + ?Sized>(&mut self, level: i32, ty: i32, value: T) {
assert!(mem::align_of::<T>() <= mem::align_of::<WinSock::CMSGHDR>());
let space = wsa_cmsg_space(mem::size_of_val(&value) as _);
/// - If `T` has stricter alignment requirements than the native type
/// - level and type fields of cmsg must of a type compatible with [`std::ffi::c_int`]`
pub(crate) fn push<T: Copy + ?Sized>(&mut self, level: c_int, ty: c_int, value: T) {
assert!(mem::align_of::<T>() <= CmsgHelper::cmsghdr_align_of());
let space = CmsgHelper::cmsg_space(mem::size_of_val(&value) as _);
assert!(
self.hdr.Control.len as usize >= self.len + space,
self.helper.control_len() >= self.len + space,
"control message buffer too small. Required: {}, Available: {}",
self.len + space,
self.hdr.Control.len
self.helper.control_len()
);
let cmsg = self.cmsg.take().expect("no control buffer space remaining");
let cmsg = self
.helper
.cmsg_take()
.expect("no control buffer space remaining");
cmsg.cmsg_level = level;
cmsg.cmsg_type = ty;
cmsg.cmsg_len = wsa_cmsg_len(mem::size_of_val(&value) as _) as _;
cmsg.cmsg_len = CmsgHelper::cmsg_len(mem::size_of_val(&value) as _) as _;
unsafe {
ptr::write(wsa_cmsg_data(cmsg) as *const T as *mut T, value);
ptr::write(CmsgHelper::cmsg_data(cmsg) as *const T as *mut T, value);
}
self.len += space;
self.cmsg = unsafe { wsa_cmsg_nxthdr(self.hdr, cmsg).as_mut() };

self.helper
.set_cmsg(unsafe { CmsgHelper::cmsg_nxthdr(self.helper, cmsg).as_mut() });
}

/// Finishes appending control messages to the buffer
Expand All @@ -107,46 +58,9 @@ impl<'a> Encoder<'a> {
}

// Statically guarantees that the encoding operation is "finished" before the control buffer is read
// by `WSASendMsg`.
// by sendmsg like functions.
impl<'a> Drop for Encoder<'a> {
fn drop(&mut self) {
self.hdr.Control.len = self.len as _;
}
}

/// # Safety
///
/// `cmsg` must refer to a [`WinSock::CMSGHDR`] containing a payload of type `T`
pub(crate) unsafe fn decode<T: Copy>(cmsg: &WinSock::CMSGHDR) -> T {
assert!(mem::align_of::<T>() <= mem::align_of::<WinSock::CMSGHDR>());
debug_assert_eq!(cmsg.cmsg_len, wsa_cmsg_len(mem::size_of::<T>() as _));
ptr::read(wsa_cmsg_data(cmsg) as *const T)
}

pub(crate) struct Iter<'a> {
hdr: &'a WinSock::WSAMSG,
cmsg: Option<&'a WinSock::CMSGHDR>,
}

impl<'a> Iter<'a> {
/// # Safety
///
/// `hdr.Control.buf` must point to memory outliving `'a` which can be soundly read for the
/// lifetime of the constructed `Iter` and contains a buffer of [`WinSock::CMSGHDR`], i.e.
/// is aligned for [`WinSock::CMSGHDR`], is fully initialized, and has correct internal links.
pub(crate) unsafe fn new(hdr: &'a WinSock::WSAMSG) -> Self {
Self {
hdr,
cmsg: wsa_cmsg_firsthdr(hdr).as_ref(),
}
}
}

impl<'a> Iterator for Iter<'a> {
type Item = &'a WinSock::CMSGHDR;
fn next(&mut self) -> Option<&'a WinSock::CMSGHDR> {
let current = self.cmsg.take()?;
self.cmsg = unsafe { wsa_cmsg_nxthdr(self.hdr, current).as_ref() };
Some(current)
self.helper.set_control_len(self.len);
}
}

0 comments on commit c9b82fc

Please sign in to comment.