Skip to content

Commit

Permalink
add the const apply check
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 20, 2023
1 parent 612a732 commit 0726ab6
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions reference/test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ TYPED_TEST(Ell, AppliesToBatchMultiVector)
}


TYPED_TEST(Ell, ConstAppliesToBatchMultiVector)
{
using T = typename TestFixture::value_type;
using BMtx = typename TestFixture::BMtx;

static_cast<const BMtx*>(this->mtx_0.get())->apply(this->b_0, this->x_0);

this->mtx_00->apply(this->b_00.get(), this->x_00.get());
this->mtx_01->apply(this->b_01.get(), this->x_01.get());
auto res = gko::batch::unbatch<gko::batch::MultiVector<T>>(this->x_0.get());
GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), r<T>::value);
GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), r<T>::value);
}


TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector)
{
using BMtx = typename TestFixture::BMtx;
Expand All @@ -154,6 +169,32 @@ TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector)
}


TYPED_TEST(Ell, ConstAppliesLinearCombinationToBatchMultiVector)
{
using BMtx = typename TestFixture::BMtx;
using BMVec = typename TestFixture::BMVec;
using DenseMtx = typename TestFixture::DenseMtx;
using T = typename TestFixture::value_type;
auto alpha = gko::batch::initialize<BMVec>({{1.5}, {-1.0}}, this->exec);
auto beta = gko::batch::initialize<BMVec>({{2.5}, {-4.0}}, this->exec);
auto alpha0 = gko::initialize<DenseMtx>({1.5}, this->exec);
auto alpha1 = gko::initialize<DenseMtx>({-1.0}, this->exec);
auto beta0 = gko::initialize<DenseMtx>({2.5}, this->exec);
auto beta1 = gko::initialize<DenseMtx>({-4.0}, this->exec);

static_cast<const BMtx*>(this->mtx_0.get())
->apply(alpha.get(), this->b_0.get(), beta.get(), this->x_0.get());

this->mtx_00->apply(alpha0.get(), this->b_00.get(), beta0.get(),
this->x_00.get());
this->mtx_01->apply(alpha1.get(), this->b_01.get(), beta1.get(),
this->x_01.get());
auto res = gko::batch::unbatch<gko::batch::MultiVector<T>>(this->x_0.get());
GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), r<T>::value);
GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), r<T>::value);
}


TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols)
{
using BMVec = typename TestFixture::BMVec;
Expand Down

0 comments on commit 0726ab6

Please sign in to comment.