Skip to content

Commit

Permalink
Fix AVX512 mulmod (#50)
Browse files Browse the repository at this point in the history
* Fix AVX512 mulmod
  • Loading branch information
fboemer committed Aug 23, 2021
1 parent 3c73eff commit c73f0bb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
9 changes: 7 additions & 2 deletions hexl/eltwise/eltwise-mult-mod-avx512dq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,8 +789,13 @@ void EltwiseMultModAVX512Float(uint64_t* result, const uint64_t* operand1,
const __m512i* vp_operand2 = reinterpret_cast<const __m512i*>(operand2);
__m512i* vp_result = reinterpret_cast<__m512i*>(result);

bool no_reduce_mod = (InputModFactor * modulus) < MaximumValue(50);
if (no_reduce_mod) { // No input modulus reduction necessary
// The implementation without modular reduction of the operands is correct
// as long as (InputModFactor * modulus)^2 < 2^50 * modulus, i.e.
// InputModFactor^2 * modulus < 2^50.
// See function 16 of https://arxiv.org/pdf/1407.3383.pdf.
bool no_input_reduce_mod =
(InputModFactor * InputModFactor * modulus) < (1ULL << 50);
if (no_input_reduce_mod) {
EltwiseMultModAVX512FloatLoop<1>(vp_result, vp_operand1, vp_operand2, v_u,
v_p, v_modulus, v_twice_mod, n);
} else {
Expand Down
20 changes: 20 additions & 0 deletions test/test-eltwise-mult-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,26 @@ TEST(EltwiseMultMod, Big) {
CheckEqual(result, exp_out);
}

TEST(EltwiseMultMod, AVX512FloatInPlaceNoInputReduceMod) {
uint64_t input_mod_factor = 4;
uint64_t modulus = 281474976546817;
std::uniform_int_distribution<uint64_t> distrib(
0, input_mod_factor * modulus - 1);

std::vector<uint64_t> data_native(8, 998771110802331);
auto data_avx = data_native;

EltwiseMultModAVX512Float<4>(data_avx.data(), data_avx.data(),
data_avx.data(), data_avx.size(), modulus);

EltwiseMultModNative<4>(data_native.data(), data_native.data(),
data_native.data(), data_avx.size(), modulus);

CheckEqual(data_native, std::vector<uint64_t>(8, 273497826869315));
CheckEqual(data_avx, std::vector<uint64_t>(8, 273497826869315));
CheckEqual(data_avx, data_native);
}

TEST(EltwiseMultMod, avx512dqint_small) {
if (!has_avx512dq) {
GTEST_SKIP();
Expand Down

0 comments on commit c73f0bb

Please sign in to comment.