Skip to content

Commit

Permalink
feat(s2n-quic-platform): add message ring using sync::Cursor
Browse files Browse the repository at this point in the history
  • Loading branch information
camshaft committed Jun 8, 2023
1 parent 56ef974 commit 901069f
Show file tree
Hide file tree
Showing 6 changed files with 498 additions and 8 deletions.
16 changes: 14 additions & 2 deletions quic/s2n-quic-platform/src/message.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use core::ffi::c_void;
use core::{cell::UnsafeCell, ffi::c_void, pin::Pin};
use s2n_quic_core::{inet::datagram, io::tx, path};

#[cfg(any(s2n_quic_platform_socket_msg, s2n_quic_platform_socket_mmsg))]
Expand All @@ -25,12 +25,17 @@ pub mod default {
}
}

pub type Storage = Pin<Box<[UnsafeCell<u8>]>>;

/// An abstract message that can be sent and received on a network
pub trait Message {
pub trait Message: 'static + Copy {
type Handle: path::Handle;

const SUPPORTS_GSO: bool;

/// Allocates `entries` messages, each with `payload_len` bytes
fn alloc(entries: u32, payload_len: u32, offset: usize) -> Storage;

/// Returns the length of the payload
fn payload_len(&self) -> usize;

Expand All @@ -50,6 +55,13 @@ pub trait Message {
/// This should used in scenarios where the data pointers are the same.
fn replicate_fields_from(&mut self, other: &Self);

/// Validates that the `source` message can be replicated to `dest`.
///
/// # Panics
///
/// This panics when the messages cannot be replicated
fn validate_replication(source: &Self, dest: &Self);

/// Returns a mutable pointer for the message payload
fn payload_ptr_mut(&mut self) -> *mut u8;

Expand Down
15 changes: 15 additions & 0 deletions quic/s2n-quic-platform/src/message/mmsg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ impl MessageTrait for mmsghdr {

const SUPPORTS_GSO: bool = libc::msghdr::SUPPORTS_GSO;

#[inline]
fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage {
unsafe {
msg::alloc(entries, payload_len, offset, |mmsghdr: &mut mmsghdr| {
mmsghdr.msg_len = payload_len as _;
&mut mmsghdr.msg_hdr
})
}
}

#[inline]
fn payload_len(&self) -> usize {
self.msg_len as usize
Expand Down Expand Up @@ -57,6 +67,11 @@ impl MessageTrait for mmsghdr {
self.msg_hdr.replicate_fields_from(&other.msg_hdr)
}

#[inline]
fn validate_replication(source: &Self, dest: &Self) {
libc::msghdr::validate_replication(&source.msg_hdr, &dest.msg_hdr)
}

#[inline]
fn rx_read(
&mut self,
Expand Down
105 changes: 104 additions & 1 deletion quic/s2n-quic-platform/src/message/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
use crate::message::{cmsg, cmsg::Encoder, Message as MessageTrait};
use alloc::vec::Vec;
use core::{
mem::{size_of, zeroed},
alloc::Layout,
cell::UnsafeCell,
mem::{size_of, size_of_val, zeroed},
pin::Pin,
ptr::NonNull,
};
use libc::{c_void, iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6};
use s2n_quic_core::{
Expand Down Expand Up @@ -56,6 +59,11 @@ impl MessageTrait for msghdr {

const SUPPORTS_GSO: bool = cfg!(s2n_quic_platform_gso);

#[inline]
fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage {
unsafe { alloc(entries, payload_len, offset, |msg| msg) }
}

#[inline]
fn payload_len(&self) -> usize {
debug_assert!(!self.msg_iov.is_null());
Expand Down Expand Up @@ -151,6 +159,13 @@ impl MessageTrait for msghdr {
}
}

#[inline]
fn validate_replication(source: &Self, dest: &Self) {
assert_eq!(source.msg_name, dest.msg_name);
assert_eq!(source.msg_iov, dest.msg_iov);
assert_eq!(source.msg_control, dest.msg_control);
}

#[inline]
fn rx_read(
&mut self,
Expand Down Expand Up @@ -205,6 +220,94 @@ impl MessageTrait for msghdr {
}
}

#[inline]
pub(super) unsafe fn alloc<T: Copy + Sized, F: Fn(&mut T) -> &mut msghdr>(
entries: u32,
payload_len: u32,
offset: usize,
on_entry: F,
) -> super::Storage {
let (layout, entry_offset, header_offset, payload_offset) =
layout::<T>(entries, payload_len, offset);

let ptr = alloc::alloc::alloc_zeroed(layout);

let end_pointer = ptr.add(layout.size());

let ptr = NonNull::new(ptr).expect("could not allocate socket message ring");

{
let mut entry_ptr = ptr.as_ptr().add(entry_offset) as *mut UnsafeCell<T>;
let mut header_ptr = ptr.as_ptr().add(header_offset) as *mut UnsafeCell<Header>;
let mut payload_ptr = ptr.as_ptr().add(payload_offset) as *mut UnsafeCell<u8>;
for _ in 0..entries {
let entry = on_entry((*entry_ptr).get_mut());
(*header_ptr)
.get_mut()
.update(entry, &*payload_ptr, payload_len);

entry_ptr = entry_ptr.add(1);
debug_assert!(end_pointer >= entry_ptr as *mut u8);
header_ptr = header_ptr.add(1);
debug_assert!(end_pointer >= header_ptr as *mut u8);
payload_ptr = payload_ptr.add(payload_len as _);
debug_assert!(end_pointer >= payload_ptr as *mut u8);
}

let primary = ptr.as_ptr().add(entry_offset) as *mut T;
let secondary = primary.add(entries as _);
debug_assert!(end_pointer >= secondary.add(entries as _) as *mut u8);
core::ptr::copy_nonoverlapping(primary, secondary, entries as _);
}

let slice = core::slice::from_raw_parts_mut(ptr.as_ptr() as *mut UnsafeCell<u8>, layout.size());
Box::from_raw(slice).into()
}

fn layout<T: Copy + Sized>(
entries: u32,
payload_len: u32,
offset: usize,
) -> (Layout, usize, usize, usize) {
let cursor = Layout::array::<UnsafeCell<u8>>(offset).unwrap();
let headers = Layout::array::<UnsafeCell<Header>>(entries as _).unwrap();
let payloads =
Layout::array::<UnsafeCell<u8>>(entries as usize * payload_len as usize).unwrap();
let entries = Layout::array::<UnsafeCell<T>>((entries * 2) as usize).unwrap();
let (layout, entry_offset) = cursor.extend(entries).unwrap();
let (layout, header_offset) = layout.extend(headers).unwrap();
let (layout, payload_offset) = layout.extend(payloads).unwrap();
(layout, entry_offset, header_offset, payload_offset)
}

#[repr(C)]
struct Header {
pub iovec: Aligned<iovec>,
pub msg_name: Aligned<sockaddr_in6>,
pub cmsg: Aligned<[u8; cmsg::MAX_LEN]>,
}

#[repr(C, align(8))]
struct Aligned<T>(UnsafeCell<T>);

impl Header {
unsafe fn update(&mut self, entry: &mut msghdr, payload: &UnsafeCell<u8>, payload_len: u32) {
let iovec = self.iovec.0.get_mut();

iovec.iov_base = payload.get() as *mut _;
iovec.iov_len = payload_len as _;

let entry = &mut *entry;

entry.msg_name = self.msg_name.0.get() as *mut _;
entry.msg_namelen = size_of_val(&self.msg_name) as _;
entry.msg_iov = self.iovec.0.get();
entry.msg_iovlen = 1;
entry.msg_control = self.cmsg.0.get() as *mut _;
entry.msg_controllen = cmsg::MAX_LEN as _;
}
}

pub struct Ring<Payloads> {
pub(crate) messages: Vec<msghdr>,
pub(crate) storage: Storage<Payloads>,
Expand Down
56 changes: 55 additions & 1 deletion quic/s2n-quic-platform/src/message/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::message::Message as MessageTrait;
use alloc::vec::Vec;
use core::pin::Pin;
use core::{alloc::Layout, cell::UnsafeCell, pin::Pin, ptr::NonNull};
use s2n_quic_core::{
inet::{datagram, ExplicitCongestionNotification, SocketAddress},
io::tx,
Expand Down Expand Up @@ -41,6 +41,11 @@ pub type Handle = path::Tuple;
impl MessageTrait for Message {
type Handle = Handle;

#[inline]
fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage {
unsafe { alloc(entries, payload_len, offset) }
}

const SUPPORTS_GSO: bool = false;

fn payload_len(&self) -> usize {
Expand Down Expand Up @@ -70,6 +75,11 @@ impl MessageTrait for Message {
self.payload_len = other.payload_len;
}

#[inline]
fn validate_replication(source: &Self, dest: &Self) {
assert_eq!(source.payload_ptr, dest.payload_ptr);
}

#[inline]
fn rx_read(
&mut self,
Expand Down Expand Up @@ -117,6 +127,50 @@ impl MessageTrait for Message {
}
}

#[inline]
unsafe fn alloc(entries: u32, payload_len: u32, offset: usize) -> super::Storage {
let (layout, entry_offset, payload_offset) = layout(entries, payload_len, offset);

let ptr = alloc::alloc::alloc_zeroed(layout);

let end_pointer = ptr.add(layout.size());

let ptr = NonNull::new(ptr).expect("could not allocate socket message ring");

{
let mut entry_ptr = ptr.as_ptr().add(entry_offset) as *mut UnsafeCell<Message>;
let mut payload_ptr = ptr.as_ptr().add(payload_offset) as *mut UnsafeCell<u8>;
for _ in 0..entries {
let entry = (*entry_ptr).get_mut();
entry.payload_ptr = (*payload_ptr).get();
entry.payload_len = payload_len as _;

entry_ptr = entry_ptr.add(1);
debug_assert!(end_pointer >= entry_ptr as *mut u8);
payload_ptr = payload_ptr.add(payload_len as _);
debug_assert!(end_pointer >= payload_ptr as *mut u8);
}

let primary = ptr.as_ptr().add(entry_offset) as *mut Message;
let secondary = primary.add(entries as _);
debug_assert!(end_pointer >= secondary.add(entries as _) as *mut u8);
core::ptr::copy_nonoverlapping(primary, secondary, entries as _);
}

let slice = core::slice::from_raw_parts_mut(ptr.as_ptr() as *mut UnsafeCell<u8>, layout.size());
Box::from_raw(slice).into()
}

fn layout(entries: u32, payload_len: u32, offset: usize) -> (Layout, usize, usize) {
let cursor = Layout::array::<UnsafeCell<u8>>(offset).unwrap();
let payloads =
Layout::array::<UnsafeCell<u8>>(entries as usize * payload_len as usize).unwrap();
let entries = Layout::array::<UnsafeCell<Message>>((entries * 2) as usize).unwrap();
let (layout, entry_offset) = cursor.extend(entries).unwrap();
let (layout, payload_offset) = layout.extend(payloads).unwrap();
(layout, entry_offset, payload_offset)
}

pub struct Ring<Payloads> {
messages: Vec<Message>,

Expand Down
7 changes: 3 additions & 4 deletions quic/s2n-quic-platform/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

use cfg_if::cfg_if;

#[cfg(s2n_quic_platform_socket_msg)]
pub mod msg;

#[cfg(s2n_quic_platform_socket_mmsg)]
pub mod mmsg;

#[cfg(s2n_quic_platform_socket_msg)]
pub mod msg;
pub mod ring;
pub mod std;

cfg_if! {
Expand Down
Loading

0 comments on commit 901069f

Please sign in to comment.