Skip to content

Commit

Permalink
Fix reference compute mean impl, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
greole committed Oct 19, 2023
1 parent d152e73 commit ed9c4b7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 4 additions & 4 deletions reference/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,11 @@ void compute_mean(std::shared_ptr<const ReferenceExecutor> exec,
result->at(0, j) = zero<ValueType>();
}

for (size_type i = 0; i < x->get_size()[0]; ++i) {
for (size_type j = 0; j < x->get_size()[1]; ++j) {
result->at(0, i) += x->at(i, j);
for (size_type i = 0; i < x->get_size()[1]; ++i) {
for (size_type j = 0; j < x->get_size()[0]; ++j) {
result->at(0, i) += x->at(j, i);
}
result->at(0, i) /= static_cast<ValueType_nc>(x->get_size()[1]);
result->at(0, i) /= static_cast<ValueType_nc>(x->get_size()[0]);
}
}

Expand Down
8 changes: 8 additions & 0 deletions reference/test/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <complex>
#include <memory>
#include <numeric>
#include <random>


Expand Down Expand Up @@ -704,6 +705,13 @@ TYPED_TEST(Dense, ComputesMean)
{
using Mtx = typename TestFixture::Mtx;
using T = typename TestFixture::value_type;

auto iota = Mtx::create(this->exec, gko::dim<2>{10, 1});
std::iota(iota->get_values(), iota->get_values() + 10, 1);
auto iota_result = Mtx::create(this->exec, gko::dim<2>{1, 1});
iota->compute_mean(iota_result.get());
GKO_EXPECT_NEAR(iota_result->at(0, 0), T{5.5}, r<T>::value * 10);

auto result = Mtx::create(this->exec, gko::dim<2>{1, 3});

this->mtx4->compute_mean(result.get());
Expand Down

0 comments on commit ed9c4b7

Please sign in to comment.