Commit

Extending std::simd improvements to BytesPerPixel::Six.

anforowicz committed Sep 25, 2023
1 parent 2e6b6ee commit 788633a
Showing 1 changed file with 176 additions and 71 deletions.

src/filter.rs (176 additions, 71 deletions)
@@ -2,14 +2,18 @@ use core::convert::TryInto;

use crate::common::BytesPerPixel;

/// SIMD helpers for `fn unfilter`
/// SIMD helpers for `fn unfilter` for cases when auto-vectorization doesn't work as well:
/// - `BytesPerPixel::Three`, `BytesPerPixel::Six`
/// - `FilterType::Sub`, `FilterType::Avg`, `FilterType::Paeth`
///
/// TODO(https://github.com/rust-lang/rust/issues/86656): Stop gating this module behind the
/// "unstable" feature of the `png` crate. This should be possible once the "portable_simd"
/// feature of Rust gets stabilized.
/// "unstable" feature of the `png` crate. This should be possible once the "portable_simd" feature
/// of Rust gets stabilized.
#[cfg(feature = "unstable")]
mod simd {
use std::simd::{i16x4, u16x4, u8x4, SimdInt, SimdOrd, SimdPartialEq, SimdUint};
use std::simd::{
u8x4, u8x8, LaneCount, Simd, SimdInt, SimdOrd, SimdPartialEq, SimdUint, SupportedLaneCount,
};

/// This is an equivalent of the `PaethPredictor` function from
/// [the spec](http://www.libpng.org/pub/png/spec/1.2/PNG-Filters.html#Filter-type-4-Paeth)
@@ -22,7 +26,14 @@ mod simd {
/// - RGB => 4 lanes of `i16x4` contain R, G, B, and an ignored 4th value
///
/// The SIMD algorithm below is based on [`libpng`](https://github.com/glennrp/libpng/blob/f8e5fa92b0e37ab597616f554bee254157998227/intel/filter_sse2_intrinsics.c#L261-L280).
fn paeth_predictor(a: i16x4, b: i16x4, c: i16x4) -> i16x4 {
fn paeth_predictor<const N: usize>(
a: Simd<i16, N>,
b: Simd<i16, N>,
c: Simd<i16, N>,
) -> Simd<i16, N>
where
LaneCount<N>: SupportedLaneCount,
{
let pa = b - c; // (p-a) == (a+b-c - a) == (b-c)
let pb = a - c; // (p-b) == (a+b-c - b) == (a-c)
let pc = pa + pb; // (p-c) == (a+b-c - c) == (a+b-c-c) == (b-c)+(a-c)
@@ -48,31 +59,48 @@
u8x4::from_array([src[0], src[1], src[2], 0])
}
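
For reference, a scalar version of the Paeth predictor defined by the PNG spec looks roughly like the sketch below (illustrative, not part of this commit); the SIMD `paeth_predictor` above computes the same `pa`/`pb`/`pc` distances but selects between `a`, `b`, and `c` with lane masks instead of branches.

fn paeth_scalar(a: u8, b: u8, c: u8) -> u8 {
    let p = a as i16 + b as i16 - c as i16; // initial estimate
    let pa = (p - a as i16).abs(); // distance to a
    let pb = (p - b as i16).abs(); // distance to b
    let pc = (p - c as i16).abs(); // distance to c
    // Pick the neighbor closest to the estimate; ties prefer a, then b.
    if pa <= pb && pa <= pc {
        a
    } else if pb <= pc {
        b
    } else {
        c
    }
}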

fn store3(src: u8x4, dest: &mut [u8]) {
dest[0..3].copy_from_slice(&src.to_array()[0..3])
fn load6(src: &[u8]) -> u8x8 {
u8x8::from_array([src[0], src[1], src[2], src[3], src[4], src[5], 0, 0])
}

fn store3(simd: u8x4, dest: &mut [u8]) {
dest[0..3].copy_from_slice(&simd.to_array()[0..3])
}

fn store6(simd: u8x8, dest: &mut [u8]) {
dest[0..6].copy_from_slice(&simd.to_array()[0..6])
}
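
A quick round-trip sketch of the two 6-byte helpers, assuming only what their definitions above show:

// `load6` zero-pads the two extra lanes; `store6` writes back only the six
// meaningful bytes.
let pixel = [1u8, 2, 3, 4, 5, 6];
let v = load6(&pixel);
assert_eq!(v.to_array(), [1, 2, 3, 4, 5, 6, 0, 0]);
let mut out = [0u8; 6];
store6(v, &mut out);
assert_eq!(out, pixel);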

/// Unified abstraction over `SubState`, `AvgState`, and `PaethState`.
trait UnfilterState: Default {
trait UnfilterState<const N: usize>: Default
where
LaneCount<N>: SupportedLaneCount,
{
/// Considers the next pair of pixels (`prev` from the previous row, `curr` from the current
/// row) and mutates `curr` to undo filtering.
///
/// Note that a SIMD vector may have more lanes than there are subpixel values in a pixel.
/// The implementation always operates on all the lanes, but when the caller processes RGB8
/// pixels, the 4th SIMD lane is in practice ignored.
fn step(&mut self, prev: u8x4, curr: &mut u8x4);
fn step(&mut self, prev: Simd<u8, N>, curr: &mut Simd<u8, N>);
}

/// Memory of previous pixels (as needed to unfilter `FilterType::Sub`).
/// See also https://www.w3.org/TR/png/#filter-byte-positions
#[derive(Default)]
struct SubState {
struct SubState<const N: usize>
where
LaneCount<N>: SupportedLaneCount,
{
/// Previous pixel in the current row.
a: u8x4,
a: Simd<u8, N>,
}

impl UnfilterState for SubState {
fn step(&mut self, _prev: u8x4, curr: &mut u8x4) {
impl<const N: usize> UnfilterState<N> for SubState<N>
where
LaneCount<N>: SupportedLaneCount,
{
fn step(&mut self, _prev: Simd<u8, N>, curr: &mut Simd<u8, N>) {
// Calculating the new value of the current pixel.
*curr += self.a;
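// Note: `+` on `Simd<u8, N>` uses wrapping semantics on overflow (a property of
// `std::simd` integer ops), matching the scalar `wrapping_add` fallback further down.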

@@ -84,19 +112,25 @@
/// Memory of previous pixels (as needed to unfilter `FilterType::Avg`).
/// See also https://www.w3.org/TR/png/#filter-byte-positions
#[derive(Default)]
struct AvgState {
struct AvgState<const N: usize>
where
LaneCount<N>: SupportedLaneCount,
{
/// Previous pixel in the current row.
a: u16x4,
a: Simd<u16, N>,
}

impl UnfilterState for AvgState {
fn step(&mut self, prev: u8x4, curr: &mut u8x4) {
impl<const N: usize> UnfilterState<N> for AvgState<N>
where
LaneCount<N>: SupportedLaneCount,
{
fn step(&mut self, prev: Simd<u8, N>, curr: &mut Simd<u8, N>) {
// Storing the inputs.
let b = prev.cast::<u16>();
let x = curr;

// Calculating the new value of the current pixel.
let one = u16x4::splat(1);
let one = Simd::<u16, N>::splat(1);
let avg = (self.a + b) >> one;
*x += avg.cast::<u8>();

@@ -108,16 +142,22 @@
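
Keeping `a` widened to `u16` in `AvgState` above matters: the Avg prediction is `floor((a + b) / 2)`, and the sum can exceed `u8::MAX`. A scalar sketch of the same computation (hypothetical helper, not from this commit):

fn avg_prediction(a: u8, b: u8) -> u8 {
    // Widen before adding: e.g. a = 200, b = 180 sums to 380, which would wrap
    // in u8 arithmetic and produce the wrong prediction.
    ((a as u16 + b as u16) >> 1) as u8
}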
/// Memory of previous pixels (as needed to unfilter `FilterType::Paeth`).
/// See also https://www.w3.org/TR/png/#filter-byte-positions
#[derive(Default)]
struct PaethState {
struct PaethState<const N: usize>
where
LaneCount<N>: SupportedLaneCount,
{
/// Previous pixel in the previous row.
c: i16x4,
c: Simd<i16, N>,

/// Previous pixel in the current row.
a: i16x4,
a: Simd<i16, N>,
}

impl UnfilterState for PaethState {
fn step(&mut self, prev: u8x4, curr: &mut u8x4) {
impl<const N: usize> UnfilterState<N> for PaethState<N>
where
LaneCount<N>: SupportedLaneCount,
{
fn step(&mut self, prev: Simd<u8, N>, curr: &mut Simd<u8, N>) {
// Storing the inputs.
let b = prev.cast::<i16>();
let x = curr;
@@ -132,7 +172,7 @@
}
}

fn unfilter3<T: UnfilterState>(mut prev_row: &[u8], mut curr_row: &mut [u8]) {
fn unfilter3<T: UnfilterState<4>>(mut prev_row: &[u8], mut curr_row: &mut [u8]) {
debug_assert_eq!(prev_row.len(), curr_row.len());
debug_assert_eq!(prev_row.len() % 3, 0);

@@ -164,17 +204,56 @@

/// Undoes `FilterType::Sub` for `BytesPerPixel::Three`.
pub fn unfilter_sub3(prev_row: &[u8], curr_row: &mut [u8]) {
unfilter3::<SubState>(prev_row, curr_row);
unfilter3::<SubState<4>>(prev_row, curr_row);
}

/// Undoes `FilterType::Avg` for `BytesPerPixel::Three`.
pub fn unfilter_avg3(prev_row: &[u8], curr_row: &mut [u8]) {
unfilter3::<AvgState>(prev_row, curr_row);
unfilter3::<AvgState<4>>(prev_row, curr_row);
}

/// Undoes `FilterType::Paeth` for `BytesPerPixel::Three`.
pub fn unfilter_paeth3(prev_row: &[u8], curr_row: &mut [u8]) {
unfilter3::<PaethState>(prev_row, curr_row);
unfilter3::<PaethState<4>>(prev_row, curr_row);
}
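
A hand-computed sanity sketch for the 3-byte entry points (values picked for illustration, not taken from the crate's tests):

// Two RGB8 pixels; reconstruct `curr` from Avg-filtered bytes.
// Pixel 0 averages against zeros; pixel 1 also uses reconstructed pixel 0.
let prev = [2u8, 4, 6, 8, 10, 12];
let mut curr = [1u8, 1, 1, 1, 1, 1];
unfilter_avg3(&prev, &mut curr);
assert_eq!(curr, [2, 3, 4, 6, 7, 9]);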

fn unfilter6<T: UnfilterState<8>>(mut prev_row: &[u8], mut curr_row: &mut [u8]) {
debug_assert_eq!(prev_row.len(), curr_row.len());
debug_assert_eq!(prev_row.len() % 6, 0);

let mut state = T::default();
while prev_row.len() >= 8 {
// `u8x8` loads 8 bytes at a time, so while at least 8 bytes remain we can load a
// full vector and simply ignore the last 2 lanes (the first 2 bytes of the next
// pixel). This mimics the technique used in
// https://github.com/glennrp/libpng/blob/f8e5fa92b0e37ab597616f554bee254157998227/intel/filter_sse2_intrinsics.c#L130-L131
let prev_simd = u8x8::from_slice(prev_row);
let mut curr_simd = u8x8::from_slice(curr_row);
state.step(prev_simd, &mut curr_simd);
// Writing all 8 bytes back might similarly be more efficient, but it is not an
// option here: that would clobber the first 2 bytes of the next pixel in
// `curr_row` before they have been unfiltered.
store6(curr_simd, curr_row);
prev_row = &prev_row[6..];
curr_row = &mut curr_row[6..];
}
// The final pixel has only 6 bytes left, so `u8x8::from_slice` would read out of
// bounds; fall back to the zero-padded `load6`.
let prev_simd = load6(prev_row);
let mut curr_simd = load6(curr_row);
state.step(prev_simd, &mut curr_simd);
store6(curr_simd, curr_row);
}

pub fn unfilter_sub6(prev_row: &[u8], curr_row: &mut [u8]) {
unfilter6::<SubState<8>>(prev_row, curr_row);
}

pub fn unfilter_avg6(prev_row: &[u8], curr_row: &mut [u8]) {
unfilter6::<AvgState<8>>(prev_row, curr_row);
}

pub fn unfilter_paeth6(prev_row: &[u8], curr_row: &mut [u8]) {
unfilter6::<PaethState<8>>(prev_row, curr_row);
}
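
And a matching sketch for the 6-byte path (again hand-computed):

// Two RGB16 pixels; `Sub` adds the previous pixel's bytes to the current ones,
// so the zeroed `prev` row is ignored.
let prev = [0u8; 12];
let mut curr = [10u8, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6];
unfilter_sub6(&prev, &mut curr);
assert_eq!(curr, [10, 20, 30, 40, 50, 60, 11, 22, 33, 44, 55, 66]);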
}

@@ -433,18 +512,24 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Six => {
let mut prev = [0; 6];
for chunk in current.chunks_exact_mut(6) {
let new_chunk = [
chunk[0].wrapping_add(prev[0]),
chunk[1].wrapping_add(prev[1]),
chunk[2].wrapping_add(prev[2]),
chunk[3].wrapping_add(prev[3]),
chunk[4].wrapping_add(prev[4]),
chunk[5].wrapping_add(prev[5]),
];
*TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
prev = new_chunk;
#[cfg(feature = "unstable")]
simd::unfilter_sub6(previous, current);

#[cfg(not(feature = "unstable"))]
{
let mut prev = [0; 6];
for chunk in current.chunks_exact_mut(6) {
let new_chunk = [
chunk[0].wrapping_add(prev[0]),
chunk[1].wrapping_add(prev[1]),
chunk[2].wrapping_add(prev[2]),
chunk[3].wrapping_add(prev[3]),
chunk[4].wrapping_add(prev[4]),
chunk[5].wrapping_add(prev[5]),
];
*TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
prev = new_chunk;
}
}
}
BytesPerPixel::Eight => {
@@ -524,18 +609,25 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Six => {
let mut lprev = [0; 6];
for (chunk, above) in current.chunks_exact_mut(6).zip(previous.chunks_exact(6)) {
let new_chunk = [
chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8),
chunk[1].wrapping_add(((above[1] as u16 + lprev[1] as u16) / 2) as u8),
chunk[2].wrapping_add(((above[2] as u16 + lprev[2] as u16) / 2) as u8),
chunk[3].wrapping_add(((above[3] as u16 + lprev[3] as u16) / 2) as u8),
chunk[4].wrapping_add(((above[4] as u16 + lprev[4] as u16) / 2) as u8),
chunk[5].wrapping_add(((above[5] as u16 + lprev[5] as u16) / 2) as u8),
];
*TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
lprev = new_chunk;
#[cfg(feature = "unstable")]
simd::unfilter_avg6(previous, current);

#[cfg(not(feature = "unstable"))]
{
let mut lprev = [0; 6];
for (chunk, above) in current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
{
let new_chunk = [
chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8),
chunk[1].wrapping_add(((above[1] as u16 + lprev[1] as u16) / 2) as u8),
chunk[2].wrapping_add(((above[2] as u16 + lprev[2] as u16) / 2) as u8),
chunk[3].wrapping_add(((above[3] as u16 + lprev[3] as u16) / 2) as u8),
chunk[4].wrapping_add(((above[4] as u16 + lprev[4] as u16) / 2) as u8),
chunk[5].wrapping_add(((above[5] as u16 + lprev[5] as u16) / 2) as u8),
];
*TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
lprev = new_chunk;
}
}
}
BytesPerPixel::Eight => {
@@ -638,27 +730,40 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Six => {
let mut a_bpp = [0; 6];
let mut c_bpp = [0; 6];
for (chunk, b_bpp) in current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
#[cfg(feature = "unstable")]
simd::unfilter_paeth6(previous, current);

#[cfg(not(feature = "unstable"))]
{
let new_chunk = [
chunk[0]
.wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
chunk[1]
.wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
chunk[2]
.wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
chunk[3]
.wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
chunk[4]
.wrapping_add(filter_paeth_decode(a_bpp[4], b_bpp[4], c_bpp[4])),
chunk[5]
.wrapping_add(filter_paeth_decode(a_bpp[5], b_bpp[5], c_bpp[5])),
];
*TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
a_bpp = new_chunk;
c_bpp = b_bpp.try_into().unwrap();
let mut a_bpp = [0; 6];
let mut c_bpp = [0; 6];
for (chunk, b_bpp) in
current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
{
let new_chunk = [
chunk[0].wrapping_add(filter_paeth_decode(
a_bpp[0], b_bpp[0], c_bpp[0],
)),
chunk[1].wrapping_add(filter_paeth_decode(
a_bpp[1], b_bpp[1], c_bpp[1],
)),
chunk[2].wrapping_add(filter_paeth_decode(
a_bpp[2], b_bpp[2], c_bpp[2],
)),
chunk[3].wrapping_add(filter_paeth_decode(
a_bpp[3], b_bpp[3], c_bpp[3],
)),
chunk[4].wrapping_add(filter_paeth_decode(
a_bpp[4], b_bpp[4], c_bpp[4],
)),
chunk[5].wrapping_add(filter_paeth_decode(
a_bpp[5], b_bpp[5], c_bpp[5],
)),
];
*TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
a_bpp = new_chunk;
c_bpp = b_bpp.try_into().unwrap();
}
}
}
BytesPerPixel::Eight => {
