diff --git a/quic/s2n-quic-platform/src/message.rs b/quic/s2n-quic-platform/src/message.rs index 373a5201f..6d67f0b0e 100644 --- a/quic/s2n-quic-platform/src/message.rs +++ b/quic/s2n-quic-platform/src/message.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use core::ffi::c_void; +use core::{cell::UnsafeCell, ffi::c_void, pin::Pin}; use s2n_quic_core::{inet::datagram, io::tx, path}; #[cfg(any(s2n_quic_platform_socket_msg, s2n_quic_platform_socket_mmsg))] @@ -25,12 +25,17 @@ pub mod default { } } +pub type Storage = Pin]>>; + /// An abstract message that can be sent and received on a network -pub trait Message { +pub trait Message: 'static + Copy { type Handle: path::Handle; const SUPPORTS_GSO: bool; + /// Allocates `entries` messages, each with `payload_len` bytes + fn alloc(entries: u32, payload_len: u32, offset: usize) -> Storage; + /// Returns the length of the payload fn payload_len(&self) -> usize; @@ -50,6 +55,13 @@ pub trait Message { /// This should used in scenarios where the data pointers are the same. fn replicate_fields_from(&mut self, other: &Self); + /// Validates that the `source` message can be replicated to `dest`. + /// + /// # Panics + /// + /// This panics when the messages cannot be replicated + fn validate_replication(source: &Self, dest: &Self); + /// Returns a mutable pointer for the message payload fn payload_ptr_mut(&mut self) -> *mut u8; diff --git a/quic/s2n-quic-platform/src/message/mmsg.rs b/quic/s2n-quic-platform/src/message/mmsg.rs index 06f6da272..f4f68c7f5 100644 --- a/quic/s2n-quic-platform/src/message/mmsg.rs +++ b/quic/s2n-quic-platform/src/message/mmsg.rs @@ -18,6 +18,16 @@ impl MessageTrait for mmsghdr { const SUPPORTS_GSO: bool = libc::msghdr::SUPPORTS_GSO; + #[inline] + fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage { + unsafe { + msg::alloc(entries, payload_len, offset, |mmsghdr: &mut mmsghdr| { + mmsghdr.msg_len = payload_len as _; + &mut mmsghdr.msg_hdr + }) + } + } + #[inline] fn payload_len(&self) -> usize { self.msg_len as usize @@ -57,6 +67,11 @@ impl MessageTrait for mmsghdr { self.msg_hdr.replicate_fields_from(&other.msg_hdr) } + #[inline] + fn validate_replication(source: &Self, dest: &Self) { + libc::msghdr::validate_replication(&source.msg_hdr, &dest.msg_hdr) + } + #[inline] fn rx_read( &mut self, diff --git a/quic/s2n-quic-platform/src/message/msg.rs b/quic/s2n-quic-platform/src/message/msg.rs index 516e9c303..aad02e94f 100644 --- a/quic/s2n-quic-platform/src/message/msg.rs +++ b/quic/s2n-quic-platform/src/message/msg.rs @@ -4,8 +4,11 @@ use crate::message::{cmsg, cmsg::Encoder, Message as MessageTrait}; use alloc::vec::Vec; use core::{ - mem::{size_of, zeroed}, + alloc::Layout, + cell::UnsafeCell, + mem::{size_of, size_of_val, zeroed}, pin::Pin, + ptr::NonNull, }; use libc::{c_void, iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6}; use s2n_quic_core::{ @@ -56,6 +59,11 @@ impl MessageTrait for msghdr { const SUPPORTS_GSO: bool = cfg!(s2n_quic_platform_gso); + #[inline] + fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage { + unsafe { alloc(entries, payload_len, offset, |msg| msg) } + } + #[inline] fn payload_len(&self) -> usize { debug_assert!(!self.msg_iov.is_null()); @@ -151,6 +159,13 @@ impl MessageTrait for msghdr { } } + #[inline] + fn validate_replication(source: &Self, dest: &Self) { + assert_eq!(source.msg_name, dest.msg_name); + assert_eq!(source.msg_iov, dest.msg_iov); + assert_eq!(source.msg_control, dest.msg_control); + } + #[inline] fn rx_read( &mut self, @@ -205,6 +220,94 @@ impl MessageTrait for msghdr { } } +#[inline] +pub(super) unsafe fn alloc &mut msghdr>( + entries: u32, + payload_len: u32, + offset: usize, + on_entry: F, +) -> super::Storage { + let (layout, entry_offset, header_offset, payload_offset) = + layout::(entries, payload_len, offset); + + let ptr = alloc::alloc::alloc_zeroed(layout); + + let end_pointer = ptr.add(layout.size()); + + let ptr = NonNull::new(ptr).expect("could not allocate socket message ring"); + + { + let mut entry_ptr = ptr.as_ptr().add(entry_offset) as *mut UnsafeCell; + let mut header_ptr = ptr.as_ptr().add(header_offset) as *mut UnsafeCell
; + let mut payload_ptr = ptr.as_ptr().add(payload_offset) as *mut UnsafeCell; + for _ in 0..entries { + let entry = on_entry((*entry_ptr).get_mut()); + (*header_ptr) + .get_mut() + .update(entry, &*payload_ptr, payload_len); + + entry_ptr = entry_ptr.add(1); + debug_assert!(end_pointer >= entry_ptr as *mut u8); + header_ptr = header_ptr.add(1); + debug_assert!(end_pointer >= header_ptr as *mut u8); + payload_ptr = payload_ptr.add(payload_len as _); + debug_assert!(end_pointer >= payload_ptr as *mut u8); + } + + let primary = ptr.as_ptr().add(entry_offset) as *mut T; + let secondary = primary.add(entries as _); + debug_assert!(end_pointer >= secondary.add(entries as _) as *mut u8); + core::ptr::copy_nonoverlapping(primary, secondary, entries as _); + } + + let slice = core::slice::from_raw_parts_mut(ptr.as_ptr() as *mut UnsafeCell, layout.size()); + Box::from_raw(slice).into() +} + +fn layout( + entries: u32, + payload_len: u32, + offset: usize, +) -> (Layout, usize, usize, usize) { + let cursor = Layout::array::>(offset).unwrap(); + let headers = Layout::array::>(entries as _).unwrap(); + let payloads = + Layout::array::>(entries as usize * payload_len as usize).unwrap(); + let entries = Layout::array::>((entries * 2) as usize).unwrap(); + let (layout, entry_offset) = cursor.extend(entries).unwrap(); + let (layout, header_offset) = layout.extend(headers).unwrap(); + let (layout, payload_offset) = layout.extend(payloads).unwrap(); + (layout, entry_offset, header_offset, payload_offset) +} + +#[repr(C)] +struct Header { + pub iovec: Aligned, + pub msg_name: Aligned, + pub cmsg: Aligned<[u8; cmsg::MAX_LEN]>, +} + +#[repr(C, align(8))] +struct Aligned(UnsafeCell); + +impl Header { + unsafe fn update(&mut self, entry: &mut msghdr, payload: &UnsafeCell, payload_len: u32) { + let iovec = self.iovec.0.get_mut(); + + iovec.iov_base = payload.get() as *mut _; + iovec.iov_len = payload_len as _; + + let entry = &mut *entry; + + entry.msg_name = self.msg_name.0.get() as *mut _; + entry.msg_namelen = size_of_val(&self.msg_name) as _; + entry.msg_iov = self.iovec.0.get(); + entry.msg_iovlen = 1; + entry.msg_control = self.cmsg.0.get() as *mut _; + entry.msg_controllen = cmsg::MAX_LEN as _; + } +} + pub struct Ring { pub(crate) messages: Vec, pub(crate) storage: Storage, diff --git a/quic/s2n-quic-platform/src/message/simple.rs b/quic/s2n-quic-platform/src/message/simple.rs index 08865db5d..52e23ba15 100644 --- a/quic/s2n-quic-platform/src/message/simple.rs +++ b/quic/s2n-quic-platform/src/message/simple.rs @@ -3,7 +3,7 @@ use crate::message::Message as MessageTrait; use alloc::vec::Vec; -use core::pin::Pin; +use core::{alloc::Layout, cell::UnsafeCell, pin::Pin, ptr::NonNull}; use s2n_quic_core::{ inet::{datagram, ExplicitCongestionNotification, SocketAddress}, io::tx, @@ -41,6 +41,11 @@ pub type Handle = path::Tuple; impl MessageTrait for Message { type Handle = Handle; + #[inline] + fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage { + unsafe { alloc(entries, payload_len, offset) } + } + const SUPPORTS_GSO: bool = false; fn payload_len(&self) -> usize { @@ -70,6 +75,11 @@ impl MessageTrait for Message { self.payload_len = other.payload_len; } + #[inline] + fn validate_replication(source: &Self, dest: &Self) { + assert_eq!(source.payload_ptr, dest.payload_ptr); + } + #[inline] fn rx_read( &mut self, @@ -117,6 +127,50 @@ impl MessageTrait for Message { } } +#[inline] +unsafe fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage { + let (layout, entry_offset, payload_offset) = layout(entries, payload_len, offset); + + let ptr = alloc::alloc::alloc_zeroed(layout); + + let end_pointer = ptr.add(layout.size()); + + let ptr = NonNull::new(ptr).expect("could not allocate socket message ring"); + + { + let mut entry_ptr = ptr.as_ptr().add(entry_offset) as *mut UnsafeCell; + let mut payload_ptr = ptr.as_ptr().add(payload_offset) as *mut UnsafeCell; + for _ in 0..entries { + let entry = (*entry_ptr).get_mut(); + entry.payload_ptr = (*payload_ptr).get(); + entry.payload_len = payload_len as _; + + entry_ptr = entry_ptr.add(1); + debug_assert!(end_pointer >= entry_ptr as *mut u8); + payload_ptr = payload_ptr.add(payload_len as _); + debug_assert!(end_pointer >= payload_ptr as *mut u8); + } + + let primary = ptr.as_ptr().add(entry_offset) as *mut Message; + let secondary = primary.add(entries as _); + debug_assert!(end_pointer >= secondary.add(entries as _) as *mut u8); + core::ptr::copy_nonoverlapping(primary, secondary, entries as _); + } + + let slice = core::slice::from_raw_parts_mut(ptr.as_ptr() as *mut UnsafeCell, layout.size()); + Box::from_raw(slice).into() +} + +fn layout(entries: u32, payload_len: u32, offset: usize) -> (Layout, usize, usize) { + let cursor = Layout::array::>(offset).unwrap(); + let payloads = + Layout::array::>(entries as usize * payload_len as usize).unwrap(); + let entries = Layout::array::>((entries * 2) as usize).unwrap(); + let (layout, entry_offset) = cursor.extend(entries).unwrap(); + let (layout, payload_offset) = layout.extend(payloads).unwrap(); + (layout, entry_offset, payload_offset) +} + pub struct Ring { messages: Vec, diff --git a/quic/s2n-quic-platform/src/socket.rs b/quic/s2n-quic-platform/src/socket.rs index 25d0d4fba..61e906071 100644 --- a/quic/s2n-quic-platform/src/socket.rs +++ b/quic/s2n-quic-platform/src/socket.rs @@ -3,12 +3,11 @@ use cfg_if::cfg_if; -#[cfg(s2n_quic_platform_socket_msg)] -pub mod msg; - #[cfg(s2n_quic_platform_socket_mmsg)] pub mod mmsg; - +#[cfg(s2n_quic_platform_socket_msg)] +pub mod msg; +pub mod ring; pub mod std; cfg_if! { diff --git a/quic/s2n-quic-platform/src/socket/ring.rs b/quic/s2n-quic-platform/src/socket/ring.rs new file mode 100644 index 000000000..031de311d --- /dev/null +++ b/quic/s2n-quic-platform/src/socket/ring.rs @@ -0,0 +1,307 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::message::{self, Message}; +use alloc::sync::Arc; +use core::{ + mem::size_of, + ptr::NonNull, + sync::atomic::AtomicU32, + task::{Context, Poll}, +}; +use s2n_quic_core::sync::{ + atomic_waker, + cursor::{self, Cursor}, + CachePadded, +}; + +const CURSOR_SIZE: usize = size_of::>(); +const PRODUCER_OFFSET: usize = 0; +const CONSUMER_OFFSET: usize = CURSOR_SIZE; +const DATA_OFFSET: usize = CURSOR_SIZE * 2; + +/// Creates a pair of rings for a given message type +pub fn pair(entries: u32, payload_len: u32) -> (Producer, Consumer) { + let storage = T::alloc(entries, payload_len, DATA_OFFSET); + + let storage = Arc::new(storage); + let ptr = NonNull::new(storage.as_ref()[0].get()).unwrap(); + + let wakers = atomic_waker::pair(); + + let consumer = Consumer { + cursor: unsafe { builder(ptr, entries).build_consumer() }, + wakers: wakers.0, + storage: storage.clone(), + }; + + let producer = Producer { + cursor: unsafe { builder(ptr, entries).build_producer() }, + wakers: wakers.1, + storage, + }; + + (producer, consumer) +} + +/// A consumer ring for messages +pub struct Consumer { + cursor: Cursor, + wakers: atomic_waker::Handle, + #[allow(dead_code)] + storage: Arc, +} + +/// Safety: Storage is synchronized with the Cursor +unsafe impl Send for Consumer {} +/// Safety: Storage is synchronized with the Cursor +unsafe impl Sync for Consumer {} + +impl Consumer { + /// Acquires ready-to-consume messages from the producer + #[inline] + pub fn acquire(&mut self, watermark: u32) -> u32 { + self.cursor.acquire_consumer(watermark) + } + + /// Polls ready-to-consume messages from the producer + #[inline] + pub fn poll_acquire(&mut self, watermark: u32, cx: &mut Context) -> Poll { + macro_rules! try_acquire { + () => {{ + let count = self.acquire(watermark); + + if count > 0 { + return Poll::Ready(count); + } + }}; + } + + try_acquire!(); + + self.wakers.register(cx.waker()); + + try_acquire!(); + + Poll::Pending + } + + /// Releases consumed messages to the producer + #[inline] + pub fn release(&mut self, len: u32) { + self.cursor.release_consumer(len); + + self.wakers.wake(); + } + + /// Returns the currently acquired messages + #[inline] + pub fn data(&mut self) -> &mut [T] { + let idx = self.cursor.cached_consumer(); + let len = self.cursor.cached_consumer_len(); + let ptr = self.cursor.data_ptr(); + unsafe { + let ptr = ptr.as_ptr().add(idx as _); + core::slice::from_raw_parts_mut(ptr, len as _) + } + } + + /// Returns true if the producer is not closed + #[inline] + pub fn is_open(&self) -> bool { + self.wakers.is_open() + } +} + +/// A producer ring for messages +pub struct Producer { + cursor: Cursor, + wakers: atomic_waker::Handle, + #[allow(dead_code)] + storage: Arc, +} + +/// Safety: Storage is synchronized with the Cursor +unsafe impl Send for Producer {} +/// Safety: Storage is synchronized with the Cursor +unsafe impl Sync for Producer {} + +impl Producer { + /// Acquires capacity for sending messages to the consumer + #[inline] + pub fn acquire(&mut self, watermark: u32) -> u32 { + self.cursor.acquire_producer(watermark) + } + + /// Polls capacity for sending messages to the consumer + #[inline] + pub fn poll_acquire(&mut self, watermark: u32, cx: &mut Context) -> Poll { + macro_rules! try_acquire { + () => {{ + let count = self.acquire(watermark); + + if count > 0 { + return Poll::Ready(count); + } + }}; + } + + try_acquire!(); + + self.wakers.register(cx.waker()); + + try_acquire!(); + + Poll::Pending + } + + /// Releases ready-to-consume messages to the consumer + #[inline] + pub fn release(&mut self, len: u32) { + if len == 0 { + return; + } + + debug_assert!(len <= self.cursor.cached_producer_len()); + + let idx = self.cursor.cached_producer(); + let size = self.cursor.capacity(); + + // replicate any written items to the secondary region + unsafe { + let replication_count = (size - idx).min(len); + + debug_assert_ne!(replication_count, 0); + + let ptr = self.cursor.data_ptr().as_ptr().add(idx as _); + + let primary = ptr; + let secondary = ptr.add(size as _); + + self.replicate(primary, secondary, replication_count as _); + } + + // if messages were also written to the secondary region, we need to copy them back to the + // primary region + if let Some(replication_count) = (idx + len).checked_sub(size).filter(|v| *v > 0) { + unsafe { + let ptr = self.cursor.data_ptr().as_ptr(); + + let primary = ptr; + let secondary = ptr.add(size as _); + + self.replicate(secondary, primary, replication_count as _); + } + } + + self.cursor.release_producer(len); + + self.wakers.wake(); + } + + /// Returns the empty messages for the producer + #[inline] + pub fn data(&mut self) -> &mut [T] { + let idx = self.cursor.cached_producer(); + let len = self.cursor.cached_producer_len(); + let ptr = self.cursor.data_ptr(); + unsafe { + let ptr = ptr.as_ptr().add(idx as _); + core::slice::from_raw_parts_mut(ptr, len as _) + } + } + + /// Returns true if the consumer is not closed + #[inline] + pub fn is_open(&self) -> bool { + self.wakers.is_open() + } + + /// Replicates messages from the primary to secondary memory regions + #[inline] + unsafe fn replicate(&self, primary: *mut T, secondary: *mut T, len: usize) { + debug_assert_ne!(len, 0); + + #[cfg(debug_assertions)] + { + let primary = core::slice::from_raw_parts(primary, len as _); + let secondary = core::slice::from_raw_parts(secondary, len as _); + for (primary, secondary) in primary.iter().zip(secondary) { + T::validate_replication(primary, secondary); + } + } + + core::ptr::copy_nonoverlapping(primary, secondary, len as _); + } +} + +#[inline] +unsafe fn builder(ptr: NonNull, size: u32) -> cursor::Builder { + let ptr = ptr.as_ptr(); + let producer = ptr.add(PRODUCER_OFFSET) as *mut _; + let producer = NonNull::new(producer).unwrap(); + let consumer = ptr.add(CONSUMER_OFFSET) as *mut _; + let consumer = NonNull::new(consumer).unwrap(); + let data = ptr.add(DATA_OFFSET) as *mut _; + let data = NonNull::new(data).unwrap(); + + cursor::Builder { + producer, + consumer, + data, + size, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bolero::check; + + macro_rules! replication_test { + ($name:ident, $msg:ty) => { + #[test] + fn $name() { + check!().with_type::>().for_each(|counts| { + let entries = 16; + + let (mut producer, mut consumer) = pair::<$msg>(entries, 100); + + let mut counter = 0; + + for count in counts.iter().copied() { + let count = producer.acquire(count); + + for entry in &mut producer.data()[..count as usize] { + unsafe { + entry.set_payload_len(counter); + } + counter += 1; + } + + producer.release(count); + + for idx in 0..entries { + let ptr = producer.cursor.data_ptr().as_ptr(); + unsafe { + let primary = &*ptr.add(idx as _); + let secondary = &*ptr.add((idx + entries) as _); + + assert_eq!(primary.payload_len(), secondary.payload_len()); + } + } + + let count = consumer.acquire(count); + consumer.release(count); + } + }); + } + }; + } + + replication_test!(simple_replication, crate::message::simple::Message); + #[cfg(s2n_quic_platform_socket_msg)] + replication_test!(msg_replication, crate::message::msg::Message); + #[cfg(s2n_quic_platform_socket_mmsg)] + replication_test!(mmsg_replication, crate::message::mmsg::Message); +}