Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize basic_string::find_first_of #4744

Merged
merged 29 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f85a9d0
benchmark basic_string::find_first_of
AlexGuteniev Jun 23, 2024
20f9f87
move `_Find_first_of_vectorized` to `<xutility>`
AlexGuteniev Jun 23, 2024
3ba668a
attach the optimization
AlexGuteniev Jun 23, 2024
7a5a7bf
unsigned
AlexGuteniev Jun 23, 2024
72d3d22
benchmark out-of-table case
AlexGuteniev Jun 23, 2024
2608a58
Vectorize large char type only for small needle
AlexGuteniev Jun 23, 2024
22b3494
coverage
AlexGuteniev Jun 23, 2024
6928323
fix benchmark bug
AlexGuteniev Jun 23, 2024
cd2ce42
stray
AlexGuteniev Jun 23, 2024
8cb9a23
Merge remote-tracking branch 'upstream/main' into string_attached
AlexGuteniev Aug 25, 2024
c477fbb
add `if constexpr` warning suppression
AlexGuteniev Aug 25, 2024
a97c791
Merge branch 'main' into string_attached
StephanTLavavej Aug 29, 2024
ff285ad
Step 1: Copy the "return no match" upwards.
StephanTLavavej Aug 29, 2024
2cf1f27
Step 2: Flip `_Matches._Mark` control flow.
StephanTLavavej Aug 29, 2024
ac5a88e
Step 3: `_Use_bitmap` avoids code duplication.
StephanTLavavej Aug 29, 2024
e6b015d
Extract `_Elem` instead of `_Elem_size`.
StephanTLavavej Aug 29, 2024
849092d
Fall through to "not special" instead of calling self with `_Special …
StephanTLavavej Aug 29, 2024
5dfadec
Extract `_Hay_start`, `_Hay_end`.
StephanTLavavej Aug 29, 2024
82ffddf
Comment: Mention "vectorized or serial" fallbacks.
StephanTLavavej Aug 29, 2024
44e8f8c
Comment: Reverse to "vectorization outperforms" (and slightly reduce …
StephanTLavavej Aug 29, 2024
757d784
Comment: Add "found a match".
StephanTLavavej Aug 29, 2024
04d95f2
Benchmark: `AlgType::std` => `AlgType::std_func`
StephanTLavavej Aug 29, 2024
c8f93e7
Test: `expected_ptr` => `expected_iter`
StephanTLavavej Aug 29, 2024
ef2ed39
Test: Drop leading zero hexits.
StephanTLavavej Aug 29, 2024
8da299b
Lift out the `_Is_constant_evaluated()` check.
StephanTLavavej Aug 30, 2024
e162d0d
Avoid C4127 instead of suppressing
CaseyCarter Aug 30, 2024
50a5599
Fix thinko
CaseyCarter Aug 30, 2024
5eb4dfa
Revert "Fix thinko"
CaseyCarter Aug 30, 2024
19b311a
Merge pull request #7 from CaseyCarter/string_attached
AlexGuteniev Aug 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions benchmarks/src/find_first_of.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,26 @@
#include <cstdint>
#include <cstdlib>
#include <numeric>
#include <string>
#include <type_traits>
#include <vector>

using namespace std;

// Selects which find_first_of implementation a benchmark instantiation exercises:
//   std        - the free algorithm find_first_of over iterator ranges
//   str_member - the basic_string::find_first_of member function
enum class AlgType : bool { std, str_member };

template <AlgType Alg, class T, T Start = T{'a'}>
void bm(benchmark::State& state) {
const size_t Pos = static_cast<size_t>(state.range(0));
const size_t NSize = static_cast<size_t>(state.range(1));
const size_t HSize = Pos * 2;
const size_t Which = 0;

vector<T> h(HSize, T{'.'});
vector<T> n(NSize);
iota(n.begin(), n.end(), T{'a'});
using container = conditional_t<Alg == AlgType::str_member, basic_string<T>, vector<T>>;

container h(HSize, T{'.'});
container n(NSize, T{0});
iota(n.begin(), n.end(), Start);

if (Pos >= HSize || Which >= NSize) {
abort();
Expand All @@ -29,18 +35,28 @@ void bm(benchmark::State& state) {
h[Pos] = n[Which];

for (auto _ : state) {
benchmark::DoNotOptimize(find_first_of(h.begin(), h.end(), n.begin(), n.end()));
benchmark::DoNotOptimize(h);
benchmark::DoNotOptimize(n);
if constexpr (Alg == AlgType::str_member) {
benchmark::DoNotOptimize(h.find_first_of(n.data(), 0, n.size()));
} else {
benchmark::DoNotOptimize(find_first_of(h.begin(), h.end(), n.begin(), n.end()));
}
}
}

// Registers the argument pairs shared by every find_first_of benchmark.
// Each pair is {Pos, NSize}: the benchmark reads them as state.range(0) (position of the match;
// the haystack is twice that long) and state.range(1) (needle length).
// Every pair is registered exactly once.
void common_args(auto bm) {
    bm->Args({2, 3})->Args({7, 4})->Args({9, 3})->Args({22, 5})->Args({58, 2});
    bm->Args({102, 4})->Args({325, 1})->Args({1011, 11})->Args({1502, 23})->Args({3056, 7});
}

// Free-function find_first_of over vector<T> for each unsigned element width.
BENCHMARK(bm<AlgType::std, uint8_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::std, uint16_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::std, uint32_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::std, uint64_t>)->Apply(common_args);

// basic_string<T>::find_first_of member function.
// The last instantiation starts the needle at L'\x03B1' instead of the default 'a';
// NOTE(review): presumably this is the "out-of-table" case where the needle does not fit the bitmap - confirm.
BENCHMARK(bm<AlgType::str_member, char>)->Apply(common_args);
BENCHMARK(bm<AlgType::str_member, wchar_t>)->Apply(common_args);
BENCHMARK(bm<AlgType::str_member, wchar_t, L'\x03B1'>)->Apply(common_args);

BENCHMARK_MAIN();
29 changes: 29 additions & 0 deletions stl/inc/__msvc_string_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,9 +701,38 @@ constexpr size_t _Traits_find_first_of(_In_reads_(_Hay_size) const _Traits_ptr_t
// in [_Haystack, _Haystack + _Hay_size), look for one of [_Needle, _Needle + _Needle_size), at/after _Start_at
if (_Needle_size != 0 && _Start_at < _Hay_size) { // room for match, look for it
if constexpr (_Special) {
#if _USE_STD_VECTOR_ALGORITHMS
bool _Try_vectorize = !_STD _Is_constant_evaluated() && _Hay_size - _Start_at > _Threshold_find_first_of;
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
constexpr size_t _Elem_size = sizeof(*_Haystack);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved

// Additional condition for when the vectorization outperforms the table lookup
if (_Try_vectorize && (_Elem_size == 1 || _Elem_size * _Needle_size <= 16)) {
AlexGuteniev marked this conversation as resolved.
Show resolved Hide resolved
const _Traits_ptr_t<_Traits> _Found = _STD _Find_first_of_vectorized(
_Haystack + _Start_at, _Haystack + _Hay_size, _Needle, _Needle + _Needle_size);

if (_Found != _Haystack + _Hay_size) {
return static_cast<size_t>(_Found - _Haystack);
} else {
return static_cast<size_t>(-1); // no match
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
_String_bitmap<typename _Traits::char_type> _Matches;
if (!_Matches._Mark(_Needle, _Needle + _Needle_size)) { // couldn't put one of the characters into the
// bitmap, fall back to the serial algorithm
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
#if _USE_STD_VECTOR_ALGORITHMS
if (_Try_vectorize) {
const _Traits_ptr_t<_Traits> _Found = _STD _Find_first_of_vectorized(
_Haystack + _Start_at, _Haystack + _Hay_size, _Needle, _Needle + _Needle_size);

if (_Found != _Haystack + _Hay_size) {
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
return static_cast<size_t>(_Found - _Haystack);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
} else {
return static_cast<size_t>(-1); // no match
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved

return _Traits_find_first_of<_Traits, false>(_Haystack, _Hay_size, _Start_at, _Needle, _Needle_size);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
}

Expand Down
33 changes: 0 additions & 33 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,6 @@ const void* __stdcall __std_find_last_trivial_2(const void* _First, const void*
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_4(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_8(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;

__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
Expand Down Expand Up @@ -198,27 +189,6 @@ _Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val
}
}

// Typed wrapper over the size-specific __std_find_first_of_trivial_N entry points.
// Passes the haystack [_First1, _Last1) and the needle [_First2, _Last2) through as untyped pointers,
// picks the entry point by sizeof(_Ty1), and casts the returned const void* back to _Ty1*.
// Both element types must have the same size; only sizes 1, 2, 4 and 8 are handled.
template <class _Ty1, class _Ty2>
_Ty1* _Find_first_of_vectorized(
_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, _Ty2* const _Last2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
// the entry points return const void*; const_cast restores the constness of _Ty1
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_1(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_2(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 4) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_4(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 8) {
return const_cast<_Ty1*>(
static_cast<const _Ty1*>(::__std_find_first_of_trivial_8(_First1, _Last1, _First2, _Last2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
Expand All @@ -237,9 +207,6 @@ __declspec(noalias) void _Replace_vectorized(
}
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

// Can we activate the vector algorithms for find_first_of?
template <class _It1, class _It2, class _Pr>
constexpr bool _Vector_alg_in_find_first_of_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr>;
Expand Down
33 changes: 33 additions & 0 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ const void* __stdcall __std_find_trivial_2(const void* _First, const void* _Last
const void* __stdcall __std_find_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

// find_first_of entry points taking untyped pointers: [_First1, _Last1) and [_First2, _Last2).
// The numeric suffix is the element size in bytes; _Find_first_of_vectorized selects one by sizeof.
const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_4(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_8(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;

const void* __stdcall __std_min_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
Expand Down Expand Up @@ -198,6 +207,30 @@ _Ty* _Find_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noe
}
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

// Typed front end for the __std_find_first_of_trivial_N entry points.
// Forwards the haystack [_First1, _Last1) and the needle [_First2, _Last2) to the entry point that
// matches the element size, then converts the untyped result back to a pointer into the haystack.
template <class _Ty1, class _Ty2>
_Ty1* _Find_first_of_vectorized(
    _Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, _Ty2* const _Last2) noexcept {
    _STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));

    constexpr size_t _Elem_bytes = sizeof(_Ty1);

    if constexpr (_Elem_bytes == 1) {
        const void* const _Hit = ::__std_find_first_of_trivial_1(_First1, _Last1, _First2, _Last2);
        return const_cast<_Ty1*>(static_cast<const _Ty1*>(_Hit));
    } else if constexpr (_Elem_bytes == 2) {
        const void* const _Hit = ::__std_find_first_of_trivial_2(_First1, _Last1, _First2, _Last2);
        return const_cast<_Ty1*>(static_cast<const _Ty1*>(_Hit));
    } else if constexpr (_Elem_bytes == 4) {
        const void* const _Hit = ::__std_find_first_of_trivial_4(_First1, _Last1, _First2, _Last2);
        return const_cast<_Ty1*>(static_cast<const _Ty1*>(_Hit));
    } else if constexpr (_Elem_bytes == 8) {
        const void* const _Hit = ::__std_find_first_of_trivial_8(_First1, _Last1, _First2, _Last2);
        return const_cast<_Ty1*>(static_cast<const _Ty1*>(_Hit));
    } else {
        _STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
    }
}

template <class _Ty>
_Ty* _Min_element_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
constexpr bool _Signed = is_signed_v<_Ty>;
Expand Down
67 changes: 61 additions & 6 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,12 @@ void test_case_find_first_of(const vector<T>& input_haystack, const vector<T>& i
#endif // _HAS_CXX20
}

constexpr size_t haystackDataCount = 200;
constexpr size_t needleDataCount = 35;

template <class T>
void test_find_first_of(mt19937_64& gen) {
constexpr size_t haystackDataCount = 200;
constexpr size_t needleDataCount = 35;
using TD = conditional_t<sizeof(T) == 1, int, T>;
using TD = conditional_t<sizeof(T) == 1, int, T>;
uniform_int_distribution<TD> dis('a', 'z');
vector<T> input_haystack;
vector<T> input_needle;
Expand Down Expand Up @@ -310,9 +311,7 @@ void test_case_search(const vector<T>& input_haystack, const vector<T>& input_ne

template <class T>
void test_search(mt19937_64& gen) {
constexpr size_t haystackDataCount = 200;
constexpr size_t needleDataCount = 35;
using TD = conditional_t<sizeof(T) == 1, int, T>;
using TD = conditional_t<sizeof(T) == 1, int, T>;
uniform_int_distribution<TD> dis('0', '9');
vector<T> input_haystack;
vector<T> input_needle;
Expand Down Expand Up @@ -1024,6 +1023,61 @@ void test_bitset(mt19937_64& gen) {
test_randomized_bitset_base_count<512 - 5, 32 + 10>(gen);
}

// Checks basic_string::find_first_of(ptr, pos, count) against the reference implementation
// last_known_good_find_first_of for one haystack/needle pair.
// Both results are compared as ptrdiff_t offsets from the start of the haystack, with -1 meaning "no match".
template <class T>
void test_case_string_find_first_of(const basic_string<T>& input_haystack, const basic_string<T>& input_needle) {
    const auto expected_iter = last_known_good_find_first_of(
        input_haystack.begin(), input_haystack.end(), input_needle.begin(), input_needle.end());
    const auto expected =
        (expected_iter != input_haystack.end()) ? expected_iter - input_haystack.begin() : ptrdiff_t{-1};
    // the member function reports "no match" as npos, which the cast turns into -1
    const auto actual =
        static_cast<ptrdiff_t>(input_haystack.find_first_of(input_needle.data(), 0, input_needle.size()));
    assert(expected == actual);
}

// Exercises test_case_string_find_first_of over every haystack length in [0, haystackDataCount]
// and, for each of those, every needle length in [0, needleDataCount].
// Characters are drawn from `dis` using `gen`; the needle is rebuilt from scratch for each haystack length,
// while the haystack grows by one character per outer iteration.
template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
    const auto next_char = [&] { return static_cast<T>(dis(gen)); };

    basic_string<T> haystack;
    basic_string<T> needle;
    haystack.reserve(haystackDataCount);
    needle.reserve(needleDataCount);

    while (true) {
        // start with an empty needle, then grow it one character at a time
        needle.clear();
        test_case_string_find_first_of(haystack, needle);

        while (needle.size() < needleDataCount) {
            needle.push_back(next_char());
            test_case_string_find_first_of(haystack, needle);
        }

        if (haystack.size() == haystackDataCount) {
            return; // the longest haystack has been covered
        }

        haystack.push_back(next_char());
    }
}

// Runs the basic_string find_first_of tests for character type T:
// always with characters in ['a', 'z'], and additionally, for character types of 2 bytes or more,
// with code points in [0x391, 0x3C9] (GREEK CAPITAL LETTER ALPHA .. GREEK SMALL LETTER OMEGA).
template <class T>
void test_basic_string(mt19937_64& gen) {
    // draw 32-bit integers of matching signedness and cast to T afterwards,
    // rather than instantiating uniform_int_distribution on a character type
    using dis_int_type = conditional_t<is_signed_v<T>, int32_t, uint32_t>;

    uniform_int_distribution<dis_int_type> dis_latin('a', 'z');
    test_basic_string_dis<T>(gen, dis_latin);

    if constexpr (sizeof(T) >= 2) { // these values do not fit in a 1-byte character
        uniform_int_distribution<dis_int_type> dis_greek(0x391, 0x3C9);
        test_basic_string_dis<T>(gen, dis_greek);
    }
}

// Runs test_basic_string for each character type, all drawing from the same generator `gen`.
void test_string(mt19937_64& gen) {
test_basic_string<char>(gen);
test_basic_string<wchar_t>(gen);
// char8_t is covered only when the library advertises it via the feature-test macro
#ifdef __cpp_lib_char8_t
test_basic_string<char8_t>(gen);
#endif // __cpp_lib_char8_t
test_basic_string<char16_t>(gen);
test_basic_string<char32_t>(gen);
}

void test_various_containers() {
test_one_container<vector<int>>(); // contiguous, vectorizable
test_one_container<deque<int>>(); // random-access, not vectorizable
Expand Down Expand Up @@ -1096,5 +1150,6 @@ int main() {
test_vector_algorithms(gen);
test_various_containers();
test_bitset(gen);
test_string(gen);
});
}
Loading