Skip to content

Commit

Permalink
fix nullptr. sycl::half have different rule in conv and full operator…
Browse files Browse the repository at this point in the history
… after 5.7
  • Loading branch information
yhmtsai committed Feb 9, 2023
1 parent 93ec7d9 commit e8bd1dc
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
32 changes: 31 additions & 1 deletion core/test/base/extended_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,27 @@ TEST_F(FloatToHalf, ConvertsNan)
{
half x = create_from_bits("0" "11111111" "00000000000000000000001");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
// Sycl put the 1000000000, but ours put mask
ASSERT_EQ(get_bits(x), get_bits("0" "11111" "1000000000"));
#else
ASSERT_EQ(get_bits(x), get_bits("0" "11111" "1111111111"));
#endif
}


TEST_F(FloatToHalf, ConvertsNegNan)
{
half x = create_from_bits("1" "11111111" "00010000000000000000000");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
// Sycl put the 1000000000, but ours put mask
ASSERT_EQ(get_bits(x), get_bits("1" "11111" "1000000000"));
#else
ASSERT_EQ(get_bits(x), get_bits("1" "11111" "1111111111"));
#endif
}


Expand Down Expand Up @@ -196,7 +208,13 @@ TEST_F(FloatToHalf, TruncatesLargeNumber)
{
half x = create_from_bits("1" "10001110" "10010011111000010000100");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
// TODO: sycl::half seems to did rounding, but ours just truncates
ASSERT_EQ(get_bits(x), get_bits("1" "11110" "1001010000"));
#else
ASSERT_EQ(get_bits(x), get_bits("1" "11110" "1001001111"));
#endif

}

Expand Down Expand Up @@ -246,15 +264,27 @@ TEST_F(HalfToFloat, ConvertsNan)
{
float x = create_from_bits("0" "11111" "0001001000");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
// sycl keeps significand
ASSERT_EQ(get_bits(x), get_bits("0" "11111111" "00010010000000000000000"));
#else
ASSERT_EQ(get_bits(x), get_bits("0" "11111111" "11111111111111111111111"));
#endif
}


TEST_F(HalfToFloat, ConvertsNegNan)
{
float x = create_from_bits("1" "11111" "0000000001");

ASSERT_EQ(get_bits(x), get_bits("1" "11111111" "11111111111111111111111"));
#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
// sycl keeps significand
ASSERT_EQ(get_bits(x), get_bits("0" "11111111" "00000000010000000000000"));
#else
ASSERT_EQ(get_bits(x), get_bits("0" "11111111" "11111111111111111111111"));
#endif
}


Expand Down
10 changes: 5 additions & 5 deletions cuda/solver/common_trs_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,14 @@ struct CudaSolveStruct : gko::solver::SolveStruct {

size_type work_size{};
// TODO: In nullptr is considered nullptr_t not casted to const
// ValueType* it works as expected now
// it does not work in cuda110/100 images
cusparse::buffer_size_ext(
handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs,
matrix->get_num_stored_elements(), one<ValueType>(), factor_descr,
matrix->get_const_values(), matrix->get_const_row_ptrs(),
matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy,
&work_size);
matrix->get_const_col_idxs(), (const ValueType*)(nullptr), num_rhs,
solve_info, policy, &work_size);

// allocate workspace
work.resize_and_reset(work_size);
Expand All @@ -257,8 +257,8 @@ struct CudaSolveStruct : gko::solver::SolveStruct {
CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs,
matrix->get_num_stored_elements(), one<ValueType>(), factor_descr,
matrix->get_const_values(), matrix->get_const_row_ptrs(),
matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy,
work.get_data());
matrix->get_const_col_idxs(), (const ValueType*)(nullptr), num_rhs,
solve_info, policy, work.get_data());
}

void solve(const matrix::Csr<ValueType, IndexType>* matrix,
Expand Down
11 changes: 9 additions & 2 deletions include/ginkgo/core/base/half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,12 @@ struct precision_converter<SourceType, ResultType, false> {

} // namespace detail

#ifdef SYCL_LANGUAGE_VERSION
// sycl::half miss the arithmetic operator to result float not half before 5.7
// (2022-06). It leads ? half : half/half ambiguous The same issue is reported
// in https://github.com/intel/llvm/issues/6028
#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || \
(__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
using half = sycl::half;
#else
/**
Expand Down Expand Up @@ -629,7 +634,9 @@ class complex<gko::half> {
value_type imag_;
};

#ifndef SYCL_LANGUAGE_VERSION
#if !(defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || \
(__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7)))
template <>
struct numeric_limits<gko::half> {
static constexpr bool is_specialized{true};
Expand Down
7 changes: 5 additions & 2 deletions include/ginkgo/core/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ using uint64 = std::uint64_t;
*/
using uintptr = std::uintptr_t;

#ifdef SYCL_LANGUAGE_VERSION
#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || \
(__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
using half = sycl::half;
#else
class half;
Expand Down Expand Up @@ -428,7 +430,8 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x,
_enable_macro(CudaExecutor, cuda)


#if GINKGO_ENABLE_HALF
// cuda half operation is supported from arch 5.3
#if GINKGO_ENABLE_HALF && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530)
#define GKO_ADAPT_HF(_macro) template _macro
#else
#define GKO_ADAPT_HF(_macro) \
Expand Down

0 comments on commit e8bd1dc

Please sign in to comment.