Skip to content

Commit

Permalink
Merge pull request #1410 from rust-ndarray/aliasing-checks
Browse files Browse the repository at this point in the history
Allow aliasing in ArrayView::from_shape
  • Loading branch information
bluss authored Aug 3, 2024
2 parents e578d58 + 516a504 commit e9e8c9d
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 53 deletions.
123 changes: 76 additions & 47 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
}
}

/// Select how aliasing is checked
///
/// For owned or mutable data:
///
/// The strides must not allow any element to be referenced by two different indices.
///
#[derive(Copy, Clone, PartialEq)]
pub(crate) enum CanIndexCheckMode
{
/// Owned or mutable: No aliasing
OwnedMutable,
/// Aliasing
ReadOnly,
}

/// Checks whether the given data and dimension meet the invariants of the
/// `ArrayBase` type, assuming the strides are created using
/// `dim.default_strides()` or `dim.fortran_strides()`.
Expand All @@ -125,12 +140,13 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
/// `A` and in units of bytes between the least address and greatest address
/// accessible by moving along all axes does not exceed `isize::MAX`.
pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(
data: &[A], dim: &D, strides: &Strides<D>,
data: &[A], dim: &D, strides: &Strides<D>, mode: CanIndexCheckMode,
) -> Result<(), ShapeError>
{
if let Strides::Custom(strides) = strides {
can_index_slice(data, dim, strides)
can_index_slice(data, dim, strides, mode)
} else {
// contiguous shapes: never aliasing, mode does not matter
can_index_slice_not_custom(data.len(), dim)
}
}
Expand Down Expand Up @@ -239,15 +255,19 @@ where D: Dimension
/// allocation. (In other words, the pointer to the first element of the array
/// must be computed using `offset_from_low_addr_ptr_to_logical_ptr` so that
/// negative strides are correctly handled.)
pub(crate) fn can_index_slice<A, D: Dimension>(data: &[A], dim: &D, strides: &D) -> Result<(), ShapeError>
///
/// Note, condition (4) is guaranteed to be checked last
pub(crate) fn can_index_slice<A, D: Dimension>(
data: &[A], dim: &D, strides: &D, mode: CanIndexCheckMode,
) -> Result<(), ShapeError>
{
// Check conditions 1 and 2 and calculate `max_offset`.
let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
can_index_slice_impl(max_offset, data.len(), dim, strides)
can_index_slice_impl(max_offset, data.len(), dim, strides, mode)
}

fn can_index_slice_impl<D: Dimension>(
max_offset: usize, data_len: usize, dim: &D, strides: &D,
max_offset: usize, data_len: usize, dim: &D, strides: &D, mode: CanIndexCheckMode,
) -> Result<(), ShapeError>
{
// Check condition 3.
Expand All @@ -260,7 +280,7 @@ fn can_index_slice_impl<D: Dimension>(
}

// Check condition 4.
if !is_empty && dim_stride_overlap(dim, strides) {
if !is_empty && mode != CanIndexCheckMode::ReadOnly && dim_stride_overlap(dim, strides) {
return Err(from_kind(ErrorKind::Unsupported));
}

Expand Down Expand Up @@ -782,6 +802,7 @@ mod test
slice_min_max,
slices_intersect,
solve_linear_diophantine_eq,
CanIndexCheckMode,
IntoDimension,
};
use crate::error::{from_kind, ErrorKind};
Expand All @@ -796,11 +817,11 @@ mod test
let v: alloc::vec::Vec<_> = (0..12).collect();
let dim = (2, 3, 2).into_dimension();
let strides = (1, 2, 6).into_dimension();
assert!(super::can_index_slice(&v, &dim, &strides).is_ok());
assert!(super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());

let strides = (2, 4, 12).into_dimension();
assert_eq!(
super::can_index_slice(&v, &dim, &strides),
super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable),
Err(from_kind(ErrorKind::OutOfBounds))
);
}
Expand Down Expand Up @@ -848,71 +869,79 @@ mod test
#[test]
fn can_index_slice_ix0()
{
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0()).unwrap();
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0()).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap();
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap_err();
}

#[test]
fn can_index_slice_ix1()
{
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1)).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0)).unwrap_err();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0)).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2)).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize)).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1)).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0), mode).unwrap_err();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0), mode).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2), mode).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize), mode).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1), mode).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize), mode).unwrap();
}

#[test]
fn can_index_slice_ix2()
{
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1)).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1)).unwrap();
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1)).unwrap_err();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap();
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap_err();

// aliasing strides: ok when readonly
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::OwnedMutable).unwrap_err();
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::ReadOnly).unwrap();
}

#[test]
fn can_index_slice_ix3()
{
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3)).unwrap();
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap();
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap_err();
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap();
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap();
}

#[test]
fn can_index_slice_zero_size_elem()
{
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1)).unwrap();
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1)).unwrap();
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1)).unwrap();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1), mode).unwrap();
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1), mode).unwrap();

// These might seem okay because the element type is zero-sized, but
// there could be a zero-sized type such that the number of instances
// in existence are carefully controlled.
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1)).unwrap_err();
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1), mode).unwrap_err();

can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0)).unwrap();
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0), mode).unwrap();
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();

// This case would be probably be sound, but that's not entirely clear
// and it's not worth the special case code.
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
}

quickcheck! {
Expand All @@ -923,8 +952,8 @@ mod test
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
result.is_err()
} else {
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides())
result == can_index_slice(&data, &dim, &dim.default_strides(), CanIndexCheckMode::OwnedMutable) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides(), CanIndexCheckMode::OwnedMutable)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/impl_constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use num_traits::{One, Zero};
use std::mem;
use std::mem::MaybeUninit;

use crate::dimension;
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::dimension::{self, CanIndexCheckMode};
use crate::error::{self, ShapeError};
use crate::extension::nonnull::nonnull_from_vec_data;
use crate::imp_prelude::*;
Expand Down Expand Up @@ -466,7 +466,7 @@ where
{
let dim = shape.dim;
let is_custom = shape.strides.is_custom();
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?;
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?;
if !is_custom && dim.size() != v.len() {
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
}
Expand Down Expand Up @@ -510,7 +510,7 @@ where
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self
{
// debug check for issues that indicates wrong use of this constructor
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());

let ptr = nonnull_from_vec_data(&mut v).add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides));
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)
Expand Down
6 changes: 3 additions & 3 deletions src/impl_views/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

use std::ptr::NonNull;

use crate::dimension;
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::dimension::{self, CanIndexCheckMode};
use crate::error::ShapeError;
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
use crate::imp_prelude::*;
Expand Down Expand Up @@ -54,7 +54,7 @@ where D: Dimension
fn from_shape_impl(shape: StrideShape<D>, xs: &'a [A]) -> Result<Self, ShapeError>
{
let dim = shape.dim;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::ReadOnly)?;
let strides = shape.strides.strides_for_dim(&dim);
unsafe {
Ok(Self::new_(
Expand Down Expand Up @@ -157,7 +157,7 @@ where D: Dimension
fn from_shape_impl(shape: StrideShape<D>, xs: &'a mut [A]) -> Result<Self, ShapeError>
{
let dim = shape.dim;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::OwnedMutable)?;
let strides = shape.strides.strides_for_dim(&dim);
unsafe {
Ok(Self::new_(
Expand Down
17 changes: 17 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use defmac::defmac;
use itertools::{zip, Itertools};
use ndarray::indices;
use ndarray::prelude::*;
use ndarray::ErrorKind;
use ndarray::{arr3, rcarr2};
use ndarray::{Slice, SliceInfo, SliceInfoElem};
use num_complex::Complex;
Expand Down Expand Up @@ -2060,6 +2061,22 @@ fn test_view_from_shape()
assert_eq!(a, answer);
}

#[test]
fn test_view_from_shape_allow_overlap()
{
let data = [0, 1, 2];
let view = ArrayView::from_shape((2, 3).strides((0, 1)), &data).unwrap();
assert_eq!(view, aview2(&[data; 2]));
}

#[test]
fn test_view_mut_from_shape_deny_overlap()
{
let mut data = [0, 1, 2];
let result = ArrayViewMut::from_shape((2, 3).strides((0, 1)), &mut data);
assert_matches!(result.map_err(|e| e.kind()), Err(ErrorKind::Unsupported));
}

#[test]
fn test_contiguous()
{
Expand Down

0 comments on commit e9e8c9d

Please sign in to comment.