diff --git a/CMakeLists.txt b/CMakeLists.txt index c66713273..d9a46f878 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -219,7 +219,7 @@ if(SEAL_USE_INTEL_HEXL) message(STATUS "Intel HEXL: download ...") seal_fetch_thirdparty_content(ExternalIntelHEXL) else() - find_package(HEXL 1.0.1) + find_package(HEXL 1.1.0) if (NOT TARGET HEXL::hexl) FATAL_ERROR("Intel HEXL: not found") endif() @@ -461,8 +461,12 @@ else() endif() if(SEAL_USE_INTEL_HEXL) - target_include_directories(seal_shared PRIVATE $) - target_link_libraries(seal_shared PRIVATE hexl) + get_target_property( + HEXL_INCLUDE_DIR + HEXL::hexl + INTERFACE_INCLUDE_DIRECTORIES) + target_include_directories(seal_shared PUBLIC ${HEXL_INCLUDE_DIR}) + target_link_libraries(seal_shared PRIVATE HEXL::hexl) endif() endif() diff --git a/README.md b/README.md index 842485d19..028782118 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ The optional dependencies and their tested versions (other versions may work as | Optional dependency | Tested version | Use | | ------------------------------------------------------ | -------------- | ------------------------------------------------ | -| [Intel HEXL](https://github.com/intel/hexl) | 1.0.1 | Acceleration of low-level kernels | +| [Intel HEXL](https://github.com/intel/hexl) | 1.1.0 | Acceleration of low-level kernels | | [Microsoft GSL](https://github.com/microsoft/GSL) | 3.1.0 | API extensions | | [ZLIB](https://github.com/madler/zlib) | 1.2.11 | Compressed serialization | | [Zstandard](https://github.com/facebook/zstd) | 1.4.5 | Compressed serialization (much faster than ZLIB) | diff --git a/cmake/ExternalIntelHEXL.cmake b/cmake/ExternalIntelHEXL.cmake index 3c32cc9a5..8a1f971d8 100644 --- a/cmake/ExternalIntelHEXL.cmake +++ b/cmake/ExternalIntelHEXL.cmake @@ -4,8 +4,8 @@ FetchContent_Declare( hexl PREFIX hexl - GIT_REPOSITORY https://github.com/intel/hexl.git - GIT_TAG v1.0.1 + GIT_REPOSITORY https://github.com/intel/hexl + GIT_TAG c28943d # v1.1.0 ) FetchContent_GetProperties(hexl) diff --git a/native/src/seal/util/intel_seal_ext.h b/native/src/seal/util/intel_seal_ext.h index 98708da1f..835760181 100644 --- a/native/src/seal/util/intel_seal_ext.h +++ b/native/src/seal/util/intel_seal_ext.h @@ -4,12 +4,119 @@ #pragma once #ifdef SEAL_USE_INTEL_HEXL +#include "seal/memorymanager.h" #include "seal/util/locks.h" #include #include "hexl/hexl.hpp" namespace intel { + namespace hexl + { + // Single threaded SEAL allocator adapter + template <> + struct NTT::AllocatorAdapter + : public AllocatorInterface> + { + AllocatorAdapter(seal::MemoryPoolHandle handle) : handle_(std::move(handle)) + {} + + ~AllocatorAdapter() + {} + + // interface implementations + void *allocate_impl(std::size_t bytes_count) + { + cache_.push_back(static_cast(handle_).get_for_byte_count(bytes_count)); + return cache_.back().get(); + } + + void deallocate_impl(void *p, std::size_t n) + { + (void)n; + auto it = std::remove_if( + cache_.begin(), cache_.end(), + [p](const seal::util::Pointer &seal_pointer) { return p == seal_pointer.get(); }); + +#ifdef SEAL_DEBUG + if (it == cache_.end()) + { + throw std::logic_error("Inconsistent single-threaded allocator cache"); + } +#endif + cache_.erase(it, cache_.end()); + } + + private: + seal::MemoryPoolHandle handle_; + std::vector> cache_; + }; + + // Thread safe policy + struct SimpleThreadSafePolicy + { + SimpleThreadSafePolicy() : m_ptr(std::make_unique()) + {} + + std::unique_lock locker() + { + return std::unique_lock{ *m_ptr }; + }; + + private: + std::unique_ptr m_ptr; + }; + + // Multithreaded SEAL allocator adapter + template <> + struct NTT::AllocatorAdapter + : public AllocatorInterface> + { + AllocatorAdapter(seal::MemoryPoolHandle handle, SimpleThreadSafePolicy &&policy) + : handle_(std::move(handle)), policy_(std::move(policy)) + {} + + ~AllocatorAdapter() + {} + // interface implementations + void *allocate_impl(std::size_t bytes_count) + { + { + // to prevent inline optimization with deadlock + auto accessor = policy_.locker(); + cache_.push_back(static_cast(handle_).get_for_byte_count(bytes_count)); + return cache_.back().get(); + } + } + + void deallocate_impl(void *p, std::size_t n) + { + (void)n; + { + // to prevent inline optimization with deadlock + auto accessor = policy_.locker(); + auto it = std::remove_if( + cache_.begin(), cache_.end(), [p](const seal::util::Pointer &seal_pointer) { + return p == seal_pointer.get(); + }); + +#ifdef SEAL_DEBUG + if (it == cache_.end()) + { + throw std::logic_error("Inconsistent multi-threaded allocator cache"); + } +#endif + cache_.erase(it, cache_.end()); + } + } + + private: + seal::MemoryPoolHandle handle_; + SimpleThreadSafePolicy policy_; + std::vector> cache_; + }; + } // namespace hexl + namespace seal_ext { struct HashPair @@ -61,7 +168,9 @@ namespace intel auto ntt_it = ntt_cache_.find(key); if (ntt_it == ntt_cache_.end()) { - ntt_it = ntt_cache_.emplace(std::move(key), intel::hexl::NTT(N, modulus, root)).first; + intel::hexl::NTT ntt( + N, modulus, root, seal::MemoryManager::GetPool(), intel::hexl::SimpleThreadSafePolicy{}); + ntt_it = ntt_cache_.emplace(std::move(key), std::move(ntt)).first; } return ntt_it->second; } diff --git a/native/src/seal/util/polyarithsmallmod.cpp b/native/src/seal/util/polyarithsmallmod.cpp index 0ea553961..8bc05d5a5 100644 --- a/native/src/seal/util/polyarithsmallmod.cpp +++ b/native/src/seal/util/polyarithsmallmod.cpp @@ -37,10 +37,14 @@ namespace seal } #endif +#ifdef SEAL_USE_INTEL_HEXL + intel::hexl::EltwiseAddMod(result, poly, scalar, coeff_count, modulus.value()); +#else SEAL_ITERATE(iter(poly, result), coeff_count, [&](auto I) { const uint64_t x = get<0>(I); get<1>(I) = add_uint_mod(x, scalar, modulus); }); +#endif } void sub_poly_scalar_coeffmod( @@ -65,10 +69,14 @@ namespace seal } #endif +#ifdef SEAL_USE_INTEL_HEXL + intel::hexl::EltwiseSubMod(result, poly, scalar, coeff_count, modulus.value()); +#else SEAL_ITERATE(iter(poly, result), coeff_count, [&](auto I) { const uint64_t x = get<0>(I); get<1>(I) = sub_uint_mod(x, scalar, modulus); }); +#endif } void multiply_poly_scalar_coeffmod( diff --git a/native/src/seal/util/polyarithsmallmod.h b/native/src/seal/util/polyarithsmallmod.h index b736996ca..888af29a1 100644 --- a/native/src/seal/util/polyarithsmallmod.h +++ b/native/src/seal/util/polyarithsmallmod.h @@ -39,8 +39,13 @@ namespace seal throw std::invalid_argument("modulus"); } #endif + +#ifdef SEAL_USE_INTEL_HEXL + intel::hexl::EltwiseReduceMod(result, poly, coeff_count, modulus.value(), 0, 1); +#else SEAL_ITERATE( iter(poly, result), coeff_count, [&](auto I) { get<1>(I) = barrett_reduce_64(get<0>(I), modulus); }); +#endif } inline void modulo_poly_coeffs( @@ -314,7 +319,11 @@ namespace seal throw std::invalid_argument("result"); } #endif + const uint64_t modulus_value = modulus.value(); +#ifdef SEAL_USE_INTEL_HEXL + intel::hexl::EltwiseSubMod(result, operand1, operand2, coeff_count, modulus_value); +#else SEAL_ITERATE(iter(operand1, operand2, result), coeff_count, [&](auto I) { #ifdef SEAL_DEBUG if (get<0>(I) >= modulus_value) @@ -330,6 +339,7 @@ namespace seal std::int64_t borrow = sub_uint64(get<0>(I), get<1>(I), &temp_result); get<2>(I) = temp_result + (modulus_value & static_cast(-borrow)); }); +#endif } inline void sub_poly_coeffmod( diff --git a/native/tests/seal/modulus.cpp b/native/tests/seal/modulus.cpp index 41970cb9b..d639800bb 100644 --- a/native/tests/seal/modulus.cpp +++ b/native/tests/seal/modulus.cpp @@ -104,7 +104,13 @@ namespace sealtest stringstream stream; Modulus mod; +#ifdef SEAL_USE_ZLIB compr_mode_type compr_mode = compr_mode_type::zlib; +#elif defined(SEAL_USE_ZSTD) + compr_mode_type compr_mode = compr_mode_type::zstd; +#else + compr_mode_type compr_mode = compr_mode_type::none; +#endif mod.save(stream, compr_mode); Modulus mod2; diff --git a/native/tests/seal/serialization.cpp b/native/tests/seal/serialization.cpp index eb31c4f91..84ed1ebdd 100644 --- a/native/tests/seal/serialization.cpp +++ b/native/tests/seal/serialization.cpp @@ -52,11 +52,15 @@ namespace sealtest Serialization::SEALHeader header; ASSERT_TRUE(Serialization::IsValidHeader(header)); - header.compr_mode = (compr_mode_type)0x01; +#ifdef SEAL_USE_ZLIB + header.compr_mode = compr_mode_type::zlib; ASSERT_TRUE(Serialization::IsValidHeader(header)); +#endif - header.compr_mode = (compr_mode_type)0x02; +#ifdef SEAL_USE_ZSTD + compr_mode_type compr_mode = compr_mode_type::zstd; ASSERT_TRUE(Serialization::IsValidHeader(header)); +#endif Serialization::SEALHeader invalid_header; invalid_header.magic = 0x1212;