diff --git a/reference/test/matrix/batch_ell_kernels.cpp b/reference/test/matrix/batch_ell_kernels.cpp index 81f189c3e02..d0e70bf5552 100644 --- a/reference/test/matrix/batch_ell_kernels.cpp +++ b/reference/test/matrix/batch_ell_kernels.cpp @@ -128,6 +128,21 @@ TYPED_TEST(Ell, AppliesToBatchMultiVector) } +TYPED_TEST(Ell, ConstAppliesToBatchMultiVector) +{ + using T = typename TestFixture::value_type; + using BMtx = typename TestFixture::BMtx; + + static_cast(this->mtx_0.get())->apply(this->b_0, this->x_0); + + this->mtx_00->apply(this->b_00.get(), this->x_00.get()); + this->mtx_01->apply(this->b_01.get(), this->x_01.get()); + auto res = gko::batch::unbatch>(this->x_0.get()); + GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), r::value); + GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), r::value); +} + + TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector) { using BMtx = typename TestFixture::BMtx; @@ -154,6 +169,32 @@ TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector) } +TYPED_TEST(Ell, ConstAppliesLinearCombinationToBatchMultiVector) +{ + using BMtx = typename TestFixture::BMtx; + using BMVec = typename TestFixture::BMVec; + using DenseMtx = typename TestFixture::DenseMtx; + using T = typename TestFixture::value_type; + auto alpha = gko::batch::initialize({{1.5}, {-1.0}}, this->exec); + auto beta = gko::batch::initialize({{2.5}, {-4.0}}, this->exec); + auto alpha0 = gko::initialize({1.5}, this->exec); + auto alpha1 = gko::initialize({-1.0}, this->exec); + auto beta0 = gko::initialize({2.5}, this->exec); + auto beta1 = gko::initialize({-4.0}, this->exec); + + static_cast(this->mtx_0.get()) + ->apply(alpha.get(), this->b_0.get(), beta.get(), this->x_0.get()); + + this->mtx_00->apply(alpha0.get(), this->b_00.get(), beta0.get(), + this->x_00.get()); + this->mtx_01->apply(alpha1.get(), this->b_01.get(), beta1.get(), + this->x_01.get()); + auto res = gko::batch::unbatch>(this->x_0.get()); + GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), r::value); + GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), r::value); +} + + TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols) { using BMVec = typename TestFixture::BMVec;