Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve count vectorization: replace popcnt implementation with vector counting #4614

Merged
merged 27 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b1aacab
`count` vectorization: replace `popcnt` implementation with vector counting
AlexGuteniev Apr 21, 2024
6f134e8
Don't do extra reduce in the end
AlexGuteniev Apr 21, 2024
35d61ac
As the SSE branch has no masked tail, can reduce in a single pace
AlexGuteniev Apr 21, 2024
ce8d8a5
Reduce as infrequently as possible
AlexGuteniev Apr 22, 2024
af456db
missing range coverage
AlexGuteniev Apr 22, 2024
c232f62
test counting zeros
AlexGuteniev Apr 22, 2024
498408c
compare with expected in new coverage
AlexGuteniev Apr 22, 2024
48efb24
separate _Count_traits_N and reuse reduce
AlexGuteniev Apr 22, 2024
403ef98
formatting
AlexGuteniev Apr 22, 2024
983e93a
Stand away from overflows!
AlexGuteniev Apr 22, 2024
64b6fe6
sizes are bytes
AlexGuteniev Apr 22, 2024
85263d8
counting overflow better coverage
AlexGuteniev Apr 22, 2024
0a92efc
fewer ops to reduce
AlexGuteniev Apr 22, 2024
d816df0
Comments cleanup
AlexGuteniev Apr 22, 2024
dedc0cc
reduce 1-byte with `sad` instruction
AlexGuteniev Apr 22, 2024
015d4f7
Simplify `_Count_traits_8::_Reduce_avx()` by reusing `_Reduce_sse()`.
StephanTLavavej Apr 24, 2024
7e39a04
Fix `_Count_traits_4::_Max_count`.
StephanTLavavej Apr 25, 2024
8d4aab5
Add detailed comments explaining each `_Max_count`.
StephanTLavavej Apr 25, 2024
50d610a
For clarity, scope `__m128i _Count_vector` to each iteration of the SSE loop
StephanTLavavej Apr 25, 2024
c06d4ff
For the AVX2 loop, scope `__m256i _Count_vector` separately for the m…
StephanTLavavej Apr 25, 2024
3f95815
Fix comment typo.
StephanTLavavej Apr 25, 2024
22b561b
Change test_count_zero() to test more lengths.
StephanTLavavej Apr 25, 2024
e587698
Restore popcnt approach
AlexGuteniev Apr 25, 2024
87749d2
Get my bounds back
AlexGuteniev Apr 25, 2024
2e1d6c7
typos
AlexGuteniev Apr 25, 2024
fc6dcbd
restore SSE4.2 comment
AlexGuteniev Apr 25, 2024
c5dd5c2
Clarify comments.
StephanTLavavej Apr 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 191 additions & 34 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1734,8 +1734,6 @@ __declspec(noalias) _Min_max_d __stdcall __std_minmax_d(const void* const _First

namespace {
struct _Find_traits_1 {
static constexpr size_t _Shift = 0;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint8_t _Val) noexcept {
return _mm256_set1_epi8(_Val);
Expand All @@ -1756,8 +1754,6 @@ namespace {
};

struct _Find_traits_2 {
static constexpr size_t _Shift = 1;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint16_t _Val) noexcept {
return _mm256_set1_epi16(_Val);
Expand All @@ -1778,8 +1774,6 @@ namespace {
};

struct _Find_traits_4 {
static constexpr size_t _Shift = 2;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint32_t _Val) noexcept {
return _mm256_set1_epi32(_Val);
Expand All @@ -1800,8 +1794,6 @@ namespace {
};

struct _Find_traits_8 {
static constexpr size_t _Shift = 3;

#ifndef _M_ARM64EC
static __m256i _Set_avx(const uint64_t _Val) noexcept {
return _mm256_set1_epi64x(_Val);
Expand Down Expand Up @@ -1978,6 +1970,130 @@ namespace {
}
}

struct _Count_traits_8 : _Find_traits_8 {
#ifndef _M_ARM64EC
    // Accumulate by subtracting the comparison mask: an all-ones (== -1) 64-bit lane increments the counter.
    static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
        return _mm256_sub_epi64(_Lhs, _Rhs);
    }

    static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
        return _mm_sub_epi64(_Lhs, _Rhs);
    }

    // Horizontal sum of four 64-bit counters: fold the high 128-bit half onto the low half,
    // then finish with the 128-bit reduction.
    static size_t _Reduce_avx(const __m256i _Val) noexcept {
        const __m128i _Low_half  = _mm256_castsi256_si128(_Val);
        const __m128i _High_half = _mm256_extracti128_si256(_Val, 1);
        return _Reduce_sse(_mm_add_epi64(_Low_half, _High_half));
    }

    // Horizontal sum of two 64-bit counters.
    static size_t _Reduce_sse(const __m128i _Val) noexcept {
#ifdef _M_IX86
        // NOTE(review): 32-bit extraction assumes each 64-bit counter fits in 32 bits on x86,
        // where size_t (and thus any element count) is 32-bit — presumably guaranteed by the callers.
        return static_cast<uint32_t>(_mm_cvtsi128_si32(_Val)) + static_cast<uint32_t>(_mm_extract_epi32(_Val, 2));
#else // ^^^ defined(_M_IX86) / defined(_M_X64) vvv
        return _mm_cvtsi128_si64(_Val) + _mm_extract_epi64(_Val, 1);
#endif // ^^^ defined(_M_X64) ^^^
    }
#endif // !_M_ARM64EC
};

struct _Count_traits_4 : _Find_traits_4 {
#ifndef _M_ARM64EC
    // For AVX2, we use hadd_epi32 three times to combine pairs of 32-bit counters into 32-bit results.
    // Therefore, _Max_count is 0x1FFF'FFFF, which is 0xFFFF'FFF8 when doubled three times; any more would overflow.

    // For SSE4.2, we use hadd_epi32 twice. This would allow a larger limit,
    // but it's simpler to use the smaller limit for both codepaths.

    static constexpr size_t _Max_count = 0x1FFF'FFFF;

    // Accumulate by subtracting the comparison mask: an all-ones (== -1) 32-bit lane increments the counter.
    static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
        return _mm256_sub_epi32(_Lhs, _Rhs);
    }

    static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
        return _mm_sub_epi32(_Lhs, _Rhs);
    }

    // Horizontal sum of eight 32-bit counters; returns the total as size_t.
    static size_t _Reduce_avx(const __m256i _Val) noexcept {
        constexpr auto _Shuf = _MM_SHUFFLE(3, 1, 2, 0); // Cross lane, to reduce further on low lane
        const __m256i _Rx4 = _mm256_hadd_epi32(_Val, _mm256_setzero_si256()); // (0+1),(2+3),0,0 per lane
        const __m256i _Rx5 = _mm256_permute4x64_epi64(_Rx4, _Shuf); // low lane (0+1),(2+3),(4+5),(6+7)
        const __m256i _Rx6 = _mm256_hadd_epi32(_Rx5, _mm256_setzero_si256()); // (0+...+3),(4+...+7),0,0
        const __m256i _Rx7 = _mm256_hadd_epi32(_Rx6, _mm256_setzero_si256()); // (0+...+7),0,0,0
        return static_cast<uint32_t>(_mm_cvtsi128_si32(_mm256_castsi256_si128(_Rx7)));
    }

    // Horizontal sum of four 32-bit counters.
    static size_t _Reduce_sse(const __m128i _Val) noexcept {
        const __m128i _Rx4 = _mm_hadd_epi32(_Val, _mm_setzero_si128()); // (0+1),(2+3),0,0
        const __m128i _Rx5 = _mm_hadd_epi32(_Rx4, _mm_setzero_si128()); // (0+...+3),0,0,0
        return static_cast<uint32_t>(_mm_cvtsi128_si32(_Rx5));
    }
#endif // !_M_ARM64EC
};

struct _Count_traits_2 : _Find_traits_2 {
#ifndef _M_ARM64EC
    // For both AVX2 and SSE4.2, we use hadd_epi16 once to combine pairs of 16-bit counters into 16-bit results.
    // Therefore, _Max_count is 0x7FFF, which is 0xFFFE when doubled; any more would overflow.

    static constexpr size_t _Max_count = 0x7FFF;

    // Accumulate by subtracting the comparison mask: an all-ones (== -1) 16-bit lane increments the counter.
    static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
        return _mm256_sub_epi16(_Lhs, _Rhs);
    }

    static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
        return _mm_sub_epi16(_Lhs, _Rhs);
    }

    // Pairwise-add the 16-bit counters, zero-extend the low results to 32 bits,
    // then reuse the 32-bit reduction to produce the total.
    static size_t _Reduce_avx(const __m256i _Val) noexcept {
        const __m256i _Paired  = _mm256_hadd_epi16(_Val, _mm256_setzero_si256());
        const __m256i _Widened = _mm256_unpacklo_epi16(_Paired, _mm256_setzero_si256());
        return _Count_traits_4::_Reduce_avx(_Widened);
    }

    static size_t _Reduce_sse(const __m128i _Val) noexcept {
        const __m128i _Paired  = _mm_hadd_epi16(_Val, _mm_setzero_si128());
        const __m128i _Widened = _mm_unpacklo_epi16(_Paired, _mm_setzero_si128());
        return _Count_traits_4::_Reduce_sse(_Widened);
    }
#endif // !_M_ARM64EC
};

struct _Count_traits_1 : _Find_traits_1 {
#ifndef _M_ARM64EC
    // For AVX2, _Max_portion_size below is _Max_count * 32 bytes, and we have 1-byte elements.
    // We're using packed 8-bit counters, and 32 of those fit in 256 bits.

    // For SSE4.2, _Max_portion_size below is _Max_count * 16 bytes, and we have 1-byte elements.
    // We're using packed 8-bit counters, and 16 of those fit in 128 bits.

    // For both codepaths, this is why _Max_count is the maximum unsigned 8-bit integer.
    // (The reduction steps aren't the limiting factor here.)

    static constexpr size_t _Max_count = 0xFF;

    // Accumulate by subtracting the comparison mask: an all-ones (== -1) byte lane increments the counter.
    static __m256i _Sub_avx(const __m256i _Lhs, const __m256i _Rhs) noexcept {
        return _mm256_sub_epi8(_Lhs, _Rhs);
    }

    static __m128i _Sub_sse(const __m128i _Lhs, const __m128i _Rhs) noexcept {
        return _mm_sub_epi8(_Lhs, _Rhs);
    }

    // sad_epu8 against zero sums each 8-byte group of byte counters into a 64-bit lane,
    // so the 64-bit reduction can finish the job.
    static size_t _Reduce_avx(const __m256i _Val) noexcept {
        const __m256i _Lane_sums = _mm256_sad_epu8(_Val, _mm256_setzero_si256());
        return _Count_traits_8::_Reduce_avx(_Lane_sums);
    }

    static size_t _Reduce_sse(const __m128i _Val) noexcept {
        const __m128i _Lane_sums = _mm_sad_epu8(_Val, _mm_setzero_si128());
        return _Count_traits_8::_Reduce_sse(_Lane_sums);
    }
#endif // !_M_ARM64EC
};

template <class _Traits, class _Ty>
__declspec(noalias) size_t
__stdcall __std_count_trivial_impl(const void* _First, const void* const _Last, const _Ty _Val) noexcept {
Expand All @@ -1986,47 +2102,88 @@ namespace {
#ifndef _M_ARM64EC
const size_t _Size_bytes = _Byte_length(_First, _Last);

if (const size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
if (size_t _Avx_size = _Size_bytes & ~size_t{0x1F}; _Avx_size != 0 && _Use_avx2()) {
const __m256i _Comparand = _Traits::_Set_avx(_Val);
const void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Avx_size);

do {
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_First));
const int _Bingo = _mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand));
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
_Advance_bytes(_First, 32);
} while (_First != _Stop_at);
for (;;) {
if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
_Advance_bytes(_Stop_at, _Avx_size);
} else {
constexpr size_t _Max_portion_size = _Traits::_Max_count * 32;
const size_t _Portion_size = _Avx_size < _Max_portion_size ? _Avx_size : _Max_portion_size;
_Advance_bytes(_Stop_at, _Portion_size);
_Avx_size -= _Portion_size;
}

__m256i _Count_vector = _mm256_setzero_si256();

do {
const __m256i _Data = _mm256_loadu_si256(static_cast<const __m256i*>(_First));
const __m256i _Mask = _Traits::_Cmp_avx(_Data, _Comparand);
_Count_vector = _Traits::_Sub_avx(_Count_vector, _Mask);
_Advance_bytes(_First, 32);
} while (_First != _Stop_at);

_Result += _Traits::_Reduce_avx(_Count_vector);

if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
break;
} else {
if (_Avx_size == 0) {
break;
}
}
}

if (const size_t _Avx_tail_size = _Size_bytes & 0x1C; _Avx_tail_size != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Avx_tail_size >> 2);
const __m256i _Data = _mm256_maskload_epi32(static_cast<const int*>(_First), _Tail_mask);
const int _Bingo =
_mm256_movemask_epi8(_mm256_and_si256(_Traits::_Cmp_avx(_Data, _Comparand), _Tail_mask));
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
const __m256i _Mask = _mm256_and_si256(_Traits::_Cmp_avx(_Data, _Comparand), _Tail_mask);
const int _Bingo = _mm256_movemask_epi8(_Mask);
const size_t _Tail_count = __popcnt(_Bingo); // Assume available with SSE4.2
_Result += _Tail_count / sizeof(_Ty);
_Advance_bytes(_First, _Avx_tail_size);
}

_mm256_zeroupper(); // TRANSITION, DevCom-10331414

_Result >>= _Traits::_Shift;

if constexpr (sizeof(_Ty) >= 4) {
return _Result;
}
} else if (const size_t _Sse_size = _Size_bytes & ~size_t{0xF}; _Sse_size != 0 && _Use_sse42()) {
} else if (size_t _Sse_size = _Size_bytes & ~size_t{0xF}; _Sse_size != 0 && _Use_sse42()) {
const __m128i _Comparand = _Traits::_Set_sse(_Val);
const void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Sse_size);

do {
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_First));
const int _Bingo = _mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand));
_Result += __popcnt(_Bingo); // Assume available with SSE4.2
_Advance_bytes(_First, 16);
} while (_First != _Stop_at);
for (;;) {
if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
_Advance_bytes(_Stop_at, _Sse_size);
} else {
constexpr size_t _Max_portion_size = _Traits::_Max_count * 16;
const size_t _Portion_size = _Sse_size < _Max_portion_size ? _Sse_size : _Max_portion_size;
_Advance_bytes(_Stop_at, _Portion_size);
_Sse_size -= _Portion_size;
}

__m128i _Count_vector = _mm_setzero_si128();

do {
const __m128i _Data = _mm_loadu_si128(static_cast<const __m128i*>(_First));
const __m128i _Mask = _Traits::_Cmp_sse(_Data, _Comparand);
_Count_vector = _Traits::_Sub_sse(_Count_vector, _Mask);
_Advance_bytes(_First, 16);
} while (_First != _Stop_at);

_Result >>= _Traits::_Shift;
_Result += _Traits::_Reduce_sse(_Count_vector);

if constexpr (sizeof(_Ty) >= sizeof(size_t)) {
break;
} else {
if (_Sse_size == 0) {
break;
}
}
}
}
#endif // !_M_ARM64EC

Expand Down Expand Up @@ -2549,22 +2706,22 @@ const void* __stdcall __std_find_last_trivial_8(

__declspec(noalias) size_t
__stdcall __std_count_trivial_1(const void* const _First, const void* const _Last, const uint8_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_1>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_1>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_count_trivial_2(const void* const _First, const void* const _Last, const uint16_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_2>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_2>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_count_trivial_4(const void* const _First, const void* const _Last, const uint32_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_4>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_4>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_count_trivial_8(const void* const _First, const void* const _Last, const uint64_t _Val) noexcept {
return __std_count_trivial_impl<_Find_traits_8>(_First, _Last, _Val);
return __std_count_trivial_impl<_Count_traits_8>(_First, _Last, _Val);
}

const void* __stdcall __std_find_first_of_trivial_1(
Expand Down
36 changes: 36 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,21 @@ void test_case_count(const vector<T>& input, T v) {
auto expected = last_known_good_count(input.begin(), input.end(), v);
auto actual = count(input.begin(), input.end(), v);
assert(expected == actual);
#if _HAS_CXX20
auto actual_r = ranges::count(input, v);
assert(actual_r == expected);
#endif // _HAS_CXX20
}

template <class T>
void test_count_zero(const vector<T>& input, const ptrdiff_t n) {
const auto first = input.begin();
const auto last = first + n;

assert(count(first, last, T{0}) == n);
#if _HAS_CXX20
assert(ranges::count(first, last, T{0}) == n);
#endif // _HAS_CXX20
}

template <class T>
Expand All @@ -96,6 +111,27 @@ void test_count(mt19937_64& gen) {
input.push_back(static_cast<T>(dis(gen)));
test_case_count(input, static_cast<T>(dis(gen)));
}

{
input.assign(1'000'000, T{0});

// test that counters don't overflow
test_count_zero(input, 1'000'000);

// Test the AVX2 maximum portion followed by all possible tail lengths, for 1-byte and 2-byte elements.
// It's okay to test these lengths for other elements, or other instruction sets.
for (ptrdiff_t i = 8'160; i < 8'192; ++i) {
test_count_zero(input, i);
}

for (ptrdiff_t i = 524'272; i < 524'288; ++i) {
test_count_zero(input, i);
}

// Test a random length.
uniform_int_distribution<ptrdiff_t> len(0, 999'999);
test_count_zero(input, len(gen));
}
}

template <class FwdIt, class T>
Expand Down