Skip to content

Commit

Permalink
Encode infallible alignment errors in types
Browse files Browse the repository at this point in the history
Permit callers to prove at compile time that alignment errors are
unreachable for unaligned destination types. This permits them to
infallibly ignore this error condition.
  • Loading branch information
joshlf committed Sep 21, 2024
1 parent 553996e commit e9e640d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 18 deletions.
80 changes: 74 additions & 6 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ use core::error::Error;
#[cfg(all(not(zerocopy_core_error), any(feature = "std", test)))]
use std::error::Error;

use crate::{util::SendSyncPhantomData, KnownLayout, TryFromBytes};
use crate::{util::SendSyncPhantomData, KnownLayout, TryFromBytes, Unaligned};
#[cfg(doc)]
use crate::{FromBytes, Ref};

Expand Down Expand Up @@ -135,6 +135,29 @@ pub enum ConvertError<A, S, V> {
Validity(V),
}

impl<Src, Dst: ?Sized + Unaligned, S, V> ConvertError<AlignmentError<Src, Dst>, S, V> {
// TODO: Bikeshed this name
pub fn recall_aligned(self) -> ConvertError<Infallible, S, V> {
match self {
Self::Alignment(e) => ConvertError::Alignment(e.into_infallible()),
Self::Size(e) => ConvertError::Size(e),
Self::Validity(e) => ConvertError::Validity(e),
}
}
}

impl<Src, Dst: ?Sized + Unaligned, S, V> Into<ConvertError<Infallible, S, V>>
for ConvertError<AlignmentError<Src, Dst>, S, V>
{
fn into(self) -> ConvertError<Infallible, S, V> {
match self {
Self::Alignment(e) => ConvertError::Alignment(e.into_infallible()),
Self::Size(e) => ConvertError::Size(e),
Self::Validity(e) => ConvertError::Validity(e),
}
}
}

impl<A: fmt::Debug, S: fmt::Debug, V: fmt::Debug> fmt::Debug for ConvertError<A, S, V> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -177,11 +200,20 @@ pub struct AlignmentError<Src, Dst: ?Sized> {
/// The source value involved in the conversion.
src: Src,
/// The inner destination type inolved in the conversion.
///
/// INVARIANT: An `AlignmentError` may only be constructed if `Dst`'s
/// alignment requirement is greater than one.
dst: SendSyncPhantomData<Dst>,
}

impl<Src, Dst: ?Sized> AlignmentError<Src, Dst> {
pub(crate) fn new(src: Src) -> Self {
/// # Safety
///
/// The caller must ensure that `Dst`'s alignment requirement is greater
/// than one.
pub(crate) unsafe fn new_unchecked(src: Src) -> Self {
// INVARIANT: The caller guarantees that `Dst`'s alignment requirement
// is greater than one.
Self { src, dst: SendSyncPhantomData::default() }
}

Expand All @@ -192,6 +224,9 @@ impl<Src, Dst: ?Sized> AlignmentError<Src, Dst> {
}

pub(crate) fn with_src<NewSrc>(self, new_src: NewSrc) -> AlignmentError<NewSrc, Dst> {
// INVARIANT: `with_src` doesn't change the type of `Dst`, so the
// invariant that `Dst`'s alignment requirement is greater than one is
// preserved.
AlignmentError { src: new_src, dst: SendSyncPhantomData::default() }
}

Expand Down Expand Up @@ -255,6 +290,28 @@ impl<Src, Dst: ?Sized> AlignmentError<Src, Dst> {
}
}

impl<Src, Dst: ?Sized + Unaligned> AlignmentError<Src, Dst> {
fn into_infallible(self) -> Infallible {
// SAFETY: `AlignmentError`s can only be constructed when `Dst`'s
// alignment requirement is greater than one. In this block, `Dst:
// Unaligned`, which means that its alignment requirement is equal to
// one. Thus, it's not possible to reach here at runtime.
unsafe { core::hint::unreachable_unchecked() }
}
}

#[cfg(test)]
impl<Src, Dst> AlignmentError<Src, Dst> {
// A convenience constructor so that test code doesn't need to write
// `unsafe`.
fn new_checked(src: Src) -> AlignmentError<Src, Dst> {
assert_ne!(core::mem::align_of::<Dst>(), 1);
// SAFETY: The preceding assertion guarantees that `Dst`'s alignment
// requirement is greater than one.
unsafe { AlignmentError::new_unchecked(src) }
}
}

impl<Src, Dst: ?Sized> fmt::Debug for AlignmentError<Src, Dst> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -626,6 +683,17 @@ impl<Src, Dst: ?Sized> CastError<Src, Dst> {
}
}

impl<Src, Dst: ?Sized + Unaligned> Into<SizeError<Src, Dst>> for CastError<Src, Dst> {
fn into(self) -> SizeError<Src, Dst> {
match self {
#[allow(unreachable_code)]
Self::Alignment(e) => match e.into_infallible() {},
Self::Size(e) => e,
Self::Validity(i) => match i {},
}
}
}

/// The error type of fallible reference conversions.
///
/// Fallible reference conversions, like [`TryFromBytes::try_ref_from_bytes`]
Expand Down Expand Up @@ -818,7 +886,7 @@ mod tests {
let bytes = &aligned.bytes[1..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 1)\
Expand All @@ -829,7 +897,7 @@ mod tests {
let bytes = &aligned.bytes[2..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 2)\
Expand All @@ -840,7 +908,7 @@ mod tests {
let bytes = &aligned.bytes[3..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 1)\
Expand All @@ -851,7 +919,7 @@ mod tests {
let bytes = &aligned.bytes[4..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 4)\
Expand Down
10 changes: 7 additions & 3 deletions src/pointer/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,8 @@ mod _transitions {
where
T: Sized,
{
if !crate::util::aligned_to::<_, T>(self.as_non_null()) {
return Err(AlignmentError::new(self));
if let Err(err) = crate::util::validate_aligned_to::<_, T>(self.as_non_null()) {
return Err(err.with_src(self));
}

// SAFETY: We just checked the alignment.
Expand Down Expand Up @@ -1204,7 +1204,11 @@ mod _casts {
let (elems, split_at) = match maybe_metadata {
Ok((elems, split_at)) => (elems, split_at),
Err(MetadataCastError::Alignment) => {
return Err(CastError::Alignment(AlignmentError::new(self)))
// SAFETY: Since `validate_cast_and_convert_metadata`
// returned an alignment error, `U` must have an alignment
// requirement greater than one.
let err = unsafe { AlignmentError::<_, U>::new_unchecked(self) };
return Err(CastError::Alignment(err));
}
Err(MetadataCastError::Size) => return Err(CastError::Size(SizeError::new(self))),
};
Expand Down
12 changes: 6 additions & 6 deletions src/ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ where
if bytes.len() != mem::size_of::<T>() {
return Err(SizeError::new(bytes).into());
}
if !util::aligned_to::<_, T>(bytes.deref()) {
return Err(AlignmentError::new(bytes).into());
if let Err(err) = util::validate_aligned_to::<_, T>(bytes.deref()) {
return Err(err.with_src(bytes).into());
}

// SAFETY: We just validated size and alignment.
Expand All @@ -220,8 +220,8 @@ where
if bytes.len() < mem::size_of::<T>() {
return Err(SizeError::new(bytes).into());
}
if !util::aligned_to::<_, T>(bytes.deref()) {
return Err(AlignmentError::new(bytes).into());
if let Err(err) = util::validate_aligned_to::<_, T>(bytes.deref()) {
return Err(err.with_src(bytes).into());
}
let (bytes, suffix) =
bytes.split_at(mem::size_of::<T>()).map_err(|b| SizeError::new(b).into())?;
Expand All @@ -243,8 +243,8 @@ where
return Err(SizeError::new(bytes).into());
};
let (prefix, bytes) = bytes.split_at(split_at).map_err(|b| SizeError::new(b).into())?;
if !util::aligned_to::<_, T>(bytes.deref()) {
return Err(AlignmentError::new(bytes).into());
if let Err(err) = util::validate_aligned_to::<_, T>(bytes.deref()) {
return Err(err.with_src(bytes).into());
}
// SAFETY: Since `split_at` is defined as `bytes_len - size_of::<T>()`,
// the `bytes` which results from `let (prefix, bytes) =
Expand Down
13 changes: 10 additions & 3 deletions src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use core::{
};

use crate::{
error::AlignmentError,
pointer::invariant::{self, Invariants},
Unalign,
};
Expand Down Expand Up @@ -547,14 +548,20 @@ impl<T: ?Sized> AsAddress for *mut T {
}
}

/// Is `t` aligned to `align_of::<U>()`?
/// Validates that `t` is aligned to `align_of::<U>()`.
#[inline(always)]
pub(crate) fn aligned_to<T: AsAddress, U>(t: T) -> bool {
pub(crate) fn validate_aligned_to<T: AsAddress, U>(t: T) -> Result<(), AlignmentError<(), U>> {
// `mem::align_of::<U>()` is guaranteed to return a non-zero value, which in
// turn guarantees that this mod operation will not panic.
#[allow(clippy::arithmetic_side_effects)]
let remainder = t.addr() % mem::align_of::<U>();
remainder == 0
if remainder == 0 {
Ok(())
} else {
// SAFETY: We just confirmed that `t.addr() % align_of::<U>() != 0`.
// That's only possible if `align_of::<U>() > 1`.
Err(unsafe { AlignmentError::new_unchecked(()) })
}
}

/// Returns the bytes needed to pad `len` to the next multiple of `align`.
Expand Down

0 comments on commit e9e640d

Please sign in to comment.