Skip to content

Commit

Permalink
update rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Mar 25, 2023
1 parent bbb0fcb commit f5113e2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
25 changes: 14 additions & 11 deletions core/test/base/extended_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,21 @@ TEST_F(FloatToHalf, TruncatesSmallNumber)
}


TEST_F(FloatToHalf, TruncatesLargeNumber)
TEST_F(FloatToHalf, TruncatesLargeNumberRoundToEven)
{
half x = create_from_bits("1" "10001110" "10010011111000010000100");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
// TODO: sycl::half seems to did rounding, but ours just truncates
ASSERT_EQ(get_bits(x), get_bits("1" "11110" "1001010000"));
#else
ASSERT_EQ(get_bits(x), get_bits("1" "11110" "1001001111"));
#endif

half neg_x = create_from_bits("1" "10001110" "10010011111000010000100");
half neg_x2 = create_from_bits("1" "10001110" "10010011101000010000100");
half x = create_from_bits("0" "10001110" "10010011111000010000100");
half x2 = create_from_bits("0" "10001110" "10010011101000010000100");
half x3 = create_from_bits("0" "10001110" "10010011101000000000000");
half x4 = create_from_bits("0" "10001110" "10010011111000000000000");

EXPECT_EQ(get_bits(x), get_bits("0" "11110" "1001010000"));
EXPECT_EQ(get_bits(x2), get_bits("0" "11110" "1001001111"));
EXPECT_EQ(get_bits(x3), get_bits("0" "11110" "1001001110"));
EXPECT_EQ(get_bits(x4), get_bits("0" "11110" "1001010000"));
EXPECT_EQ(get_bits(neg_x), get_bits("1" "11110" "1001010000"));
EXPECT_EQ(get_bits(neg_x2), get_bits("1" "11110" "1001001111"));
}


Expand Down
17 changes: 15 additions & 2 deletions include/ginkgo/core/base/half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,21 @@ class half {
// TODO: handle denormals
return conv::shift_sign(data_);
} else {
return conv::shift_sign(data_) | exp |
conv::shift_significand(data_);
// Rounding to even
const auto result = conv::shift_sign(data_) | exp |
conv::shift_significand(data_);
// return result + ((result & 1) &&
// ((data_ >> (f32_traits::significand_bits -
// f16_traits::significand_bits - 1)) &
// 1));
const auto tail =
data_ & static_cast<f32_traits::bits_type>(
(1 << conv::significand_offset) - 1);

constexpr auto half = static_cast<f32_traits::bits_type>(
1 << (conv::significand_offset - 1));
return result +
(tail > half || ((tail == half) && (result & 1)));
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions test/components/fill_array_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class FillArray : public CommonTestFixture {
protected:
using value_type = T;
FillArray()
: total_size(63531),
: total_size(3000),
vals{ref, total_size},
dvals{exec, total_size},
seqs{ref, total_size}
Expand All @@ -68,8 +68,8 @@ class FillArray : public CommonTestFixture {
gko::array<value_type> seqs;
};

TYPED_TEST_SUITE(FillArray, gko::test::ValueAndIndexTypes,
TypenameNameGenerator);
using LIST = ::testing::Types<gko::half>;
TYPED_TEST_SUITE(FillArray, LIST, TypenameNameGenerator);


TYPED_TEST(FillArray, EqualsReference)
Expand All @@ -88,5 +88,10 @@ TYPED_TEST(FillArray, FillSeqEqualsReference)
gko::kernels::EXEC_NAMESPACE::components::fill_seq_array(
this->exec, this->dvals.get_data(), this->total_size);

this->dvals.set_executor(this->ref);
for (gko::size_type i = 2000; i < this->total_size; i++) {
std::cout << i << " " << this->seqs.get_data()[i] << " device "
<< this->dvals.get_data()[i] << std::endl;
}
GKO_ASSERT_ARRAY_EQ(this->seqs, this->dvals);
}

0 comments on commit f5113e2

Please sign in to comment.