From c73f0bbbc17c7f627156aa4252381e9dbfa2a7c0 Mon Sep 17 00:00:00 2001
From: Fabian Boemer
Date: Mon, 23 Aug 2021 14:43:44 -0700
Subject: [PATCH] Fix AVX512 mulmod (#50)

* Fix AVX512 mulmod
---
 hexl/eltwise/eltwise-mult-mod-avx512dq.cpp |  9 +++++++--
 test/test-eltwise-mult-mod-avx512.cpp      | 20 ++++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/hexl/eltwise/eltwise-mult-mod-avx512dq.cpp b/hexl/eltwise/eltwise-mult-mod-avx512dq.cpp
index b62994e3..111b866b 100644
--- a/hexl/eltwise/eltwise-mult-mod-avx512dq.cpp
+++ b/hexl/eltwise/eltwise-mult-mod-avx512dq.cpp
@@ -789,8 +789,13 @@ void EltwiseMultModAVX512Float(uint64_t* result, const uint64_t* operand1,
   const __m512i* vp_operand2 = reinterpret_cast<const __m512i*>(operand2);
   __m512i* vp_result = reinterpret_cast<__m512i*>(result);
 
-  bool no_reduce_mod = (InputModFactor * modulus) < MaximumValue(50);
-  if (no_reduce_mod) {  // No input modulus reduction necessary
+  // The implementation without modular reduction of the operands is correct
+  // as long as (InputModFactor * modulus)^2 < 2^50 * modulus, i.e.
+  // InputModFactor^2 * modulus < 2^50.
+  // See function 16 of https://arxiv.org/pdf/1407.3383.pdf.
+  bool no_input_reduce_mod =
+      (InputModFactor * InputModFactor * modulus) < (1ULL << 50);
+  if (no_input_reduce_mod) {
     EltwiseMultModAVX512FloatLoop<1>(vp_result, vp_operand1, vp_operand2, v_u,
                                      v_p, v_modulus, v_twice_mod, n);
   } else {
diff --git a/test/test-eltwise-mult-mod-avx512.cpp b/test/test-eltwise-mult-mod-avx512.cpp
index 87951a9e..3285ba73 100644
--- a/test/test-eltwise-mult-mod-avx512.cpp
+++ b/test/test-eltwise-mult-mod-avx512.cpp
@@ -78,6 +78,26 @@ TEST(EltwiseMultMod, Big) {
   CheckEqual(result, exp_out);
 }
 
+TEST(EltwiseMultMod, AVX512FloatInPlaceNoInputReduceMod) {
+  uint64_t input_mod_factor = 4;
+  uint64_t modulus = 281474976546817;
+  std::uniform_int_distribution<uint64_t> distrib(
+      0, input_mod_factor * modulus - 1);
+
+  std::vector<uint64_t> data_native(8, 998771110802331);
+  auto data_avx = data_native;
+
+  EltwiseMultModAVX512Float<4>(data_avx.data(), data_avx.data(),
+                               data_avx.data(), data_avx.size(), modulus);
+
+  EltwiseMultModNative<4>(data_native.data(), data_native.data(),
+                          data_native.data(), data_avx.size(), modulus);
+
+  CheckEqual(data_native, std::vector<uint64_t>(8, 273497826869315));
+  CheckEqual(data_avx, std::vector<uint64_t>(8, 273497826869315));
+  CheckEqual(data_avx, data_native);
+}
+
 TEST(EltwiseMultMod, avx512dqint_small) {
   if (!has_avx512dq) {
     GTEST_SKIP();