Skip to content

Commit

Permalink
Review udpates.
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
  • Loading branch information
pratikvn and MarcelKoch committed Sep 27, 2023
1 parent fe3a21a commit 811b812
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 136 deletions.
72 changes: 7 additions & 65 deletions core/test/base/batch_lin_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,39 +57,21 @@ class DummyBatchLinOp : public gko::batch::EnableBatchLinOp<DummyBatchLinOp>,
: gko::batch::EnableBatchLinOp<DummyBatchLinOp>(exec, size)
{}

void access() const { last_access = this->get_executor(); }

mutable std::shared_ptr<const gko::Executor> last_access;
mutable std::shared_ptr<const gko::Executor> last_b_access;
mutable std::shared_ptr<const gko::Executor> last_x_access;
mutable std::shared_ptr<const gko::Executor> last_alpha_access;
mutable std::shared_ptr<const gko::Executor> last_beta_access;
int called = 0;

protected:
void apply_impl(const gko::batch::BatchLinOp* b,
gko::batch::BatchLinOp* x) const override
{
this->access();
static_cast<const DummyBatchLinOp*>(b)->access();
static_cast<const DummyBatchLinOp*>(x)->access();
last_b_access = b->get_executor();
last_x_access = x->get_executor();
this->called = 1;
}

void apply_impl(const gko::batch::BatchLinOp* alpha,
const gko::batch::BatchLinOp* b,
const gko::batch::BatchLinOp* beta,
gko::batch::BatchLinOp* x) const override
{
this->access();
static_cast<const DummyBatchLinOp*>(alpha)->access();
static_cast<const DummyBatchLinOp*>(b)->access();
static_cast<const DummyBatchLinOp*>(beta)->access();
static_cast<const DummyBatchLinOp*>(x)->access();
last_alpha_access = alpha->get_executor();
last_b_access = b->get_executor();
last_beta_access = beta->get_executor();
last_x_access = x->get_executor();
this->called = 2;
}
};

Expand Down Expand Up @@ -156,31 +138,31 @@ TEST_F(EnableBatchLinOp, CallsApplyImpl)
{
op->apply(b, x);

ASSERT_EQ(op->last_access, ref2);
ASSERT_EQ(op->called, 1);
}


TEST_F(EnableBatchLinOp, CallsApplyImplForBatch)
{
op2->apply(b2, x2);

ASSERT_EQ(op2->last_access, ref2);
ASSERT_EQ(op2->called, 1);
}


TEST_F(EnableBatchLinOp, CallsExtendedApplyImpl)
{
op->apply(alpha, b, beta, x);

ASSERT_EQ(op->last_access, ref2);
ASSERT_EQ(op->called, 2);
}


TEST_F(EnableBatchLinOp, CallsExtendedApplyImplBatch)
{
op2->apply(alpha2, b2, beta2, x2);

ASSERT_EQ(op2->last_access, ref2);
ASSERT_EQ(op2->called, 2);
}


Expand Down Expand Up @@ -283,46 +265,6 @@ TEST_F(EnableBatchLinOp, ExtendedApplyFailsOnWrongBetaDimension)
}


TEST_F(EnableBatchLinOp, ApplyDoesNotCopyBetweenSameMemory)
{
op->apply(b, x);

ASSERT_EQ(op->last_b_access, ref);
ASSERT_EQ(op->last_x_access, ref);
}


TEST_F(EnableBatchLinOp, ApplyNoCopyBackBetweenSameMemory)
{
op->apply(b, x);

ASSERT_EQ(b->last_access, ref);
ASSERT_EQ(x->last_access, ref);
}


TEST_F(EnableBatchLinOp, ExtendedApplyDoesNotCopyBetweenSameMemory)
{
op->apply(alpha, b, beta, x);

ASSERT_EQ(op->last_alpha_access, ref);
ASSERT_EQ(op->last_b_access, ref);
ASSERT_EQ(op->last_beta_access, ref);
ASSERT_EQ(op->last_x_access, ref);
}


TEST_F(EnableBatchLinOp, ExtendedApplyNoCopyBackBetweenSameMemory)
{
op->apply(alpha, b, beta, x);

ASSERT_EQ(alpha->last_access, ref);
ASSERT_EQ(b->last_access, ref);
ASSERT_EQ(beta->last_access, ref);
ASSERT_EQ(x->last_access, ref);
}


template <typename T = int>
class DummyBatchLinOpWithFactory
: public gko::batch::EnableBatchLinOp<DummyBatchLinOpWithFactory<T>> {
Expand Down
14 changes: 7 additions & 7 deletions include/ginkgo/core/base/batch_lin_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,6 @@ class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
return size_.get_num_batch_items();
}

/**
* Sets the size of the batch operator.
*
* @param size to be set
*/
void set_size(const batch_dim<2>& size) { size_ = size; }

/**
* Returns the size of the batch operator.
*
Expand All @@ -209,6 +202,13 @@ class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
const batch_dim<2>& get_size() const noexcept { return size_; }

protected:
/**
* Sets the size of the batch operator.
*
* @param size to be set
*/
void set_size(const batch_dim<2>& size) { size_ = size; }

/**
* Creates a batch operator with uniform batches.
*
Expand Down
108 changes: 44 additions & 64 deletions include/ginkgo/core/base/exception_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,26 +315,38 @@ inline size_type get_num_batch_items(const T& obj)
}


/**
* Asserts that _op1 and _op2 have equal number of items in the batch
*
* @throw ValueMismatch if _op1 and _op2 do not have equal number of items
*/
#define GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(_op1, _op2) \
{ \
auto equal_num_items = \
::gko::detail::get_batch_size(_op1).get_num_batch_items() == \
::gko::detail::get_batch_size(_op2).get_num_batch_items(); \
if (!equal_num_items) { \
throw ::gko::ValueMismatch( \
__FILE__, __LINE__, __func__, \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
"expected equal number of batch items"); \
} \
}


/**
* Asserts that _op1 can be applied to _op2.
*
* @throw DimensionMismatch if _op1 cannot be applied to _op2.
*/
#define GKO_ASSERT_BATCH_CONFORMANT(_op1, _op2) \
{ \
auto equal_num_items = \
::gko::detail::get_batch_size(_op1).get_num_batch_items() == \
::gko::detail::get_batch_size(_op2).get_num_batch_items(); \
GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(_op1, _op2); \
auto equal_inner_size = \
::gko::detail::get_batch_size(_op1).get_common_size()[1] == \
::gko::detail::get_batch_size(_op2).get_common_size()[0]; \
if (!equal_num_items) { \
throw ::gko::ValueMismatch( \
__FILE__, __LINE__, __func__, \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
"expected equal number of batch items"); \
} else if (!equal_inner_size) { \
if (!equal_inner_size) { \
throw ::gko::DimensionMismatch( \
__FILE__, __LINE__, __func__, #_op1, \
::gko::detail::get_batch_size(_op1).get_common_size()[0], \
Expand All @@ -354,19 +366,11 @@ inline size_type get_num_batch_items(const T& obj)
*/
#define GKO_ASSERT_BATCH_REVERSE_CONFORMANT(_op1, _op2) \
{ \
auto equal_num_items = \
::gko::detail::get_batch_size(_op1).get_num_batch_items() == \
::gko::detail::get_batch_size(_op2).get_num_batch_items(); \
GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(_op1, _op2); \
auto equal_outer_size = \
::gko::detail::get_batch_size(_op1).get_common_size()[0] == \
::gko::detail::get_batch_size(_op2).get_common_size()[1]; \
if (!equal_num_items) { \
throw ::gko::ValueMismatch( \
__FILE__, __LINE__, __func__, \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
"expected equal number of batch items"); \
} else if (!equal_outer_size) { \
if (!equal_outer_size) { \
throw ::gko::DimensionMismatch( \
__FILE__, __LINE__, __func__, #_op1, \
::gko::detail::get_batch_size(_op1).get_common_size()[0], \
Expand All @@ -386,19 +390,11 @@ inline size_type get_num_batch_items(const T& obj)
*/
#define GKO_ASSERT_BATCH_EQUAL_ROWS(_op1, _op2) \
{ \
auto equal_num_items = \
::gko::detail::get_batch_size(_op1).get_num_batch_items() == \
::gko::detail::get_batch_size(_op2).get_num_batch_items(); \
GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(_op1, _op2); \
auto equal_rows = \
::gko::detail::get_batch_size(_op1).get_common_size()[0] == \
::gko::detail::get_batch_size(_op2).get_common_size()[0]; \
if (!equal_num_items) { \
throw ::gko::ValueMismatch( \
__FILE__, __LINE__, __func__, \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
"expected equal number of batch items"); \
} else if (!equal_rows) { \
if (!equal_rows) { \
throw ::gko::DimensionMismatch( \
__FILE__, __LINE__, __func__, #_op1, \
::gko::detail::get_batch_size(_op1).get_common_size()[0], \
Expand All @@ -419,19 +415,11 @@ inline size_type get_num_batch_items(const T& obj)
*/
#define GKO_ASSERT_BATCH_EQUAL_COLS(_op1, _op2) \
{ \
auto equal_num_items = \
::gko::detail::get_batch_size(_op1).get_num_batch_items() == \
::gko::detail::get_batch_size(_op2).get_num_batch_items(); \
GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(_op1, _op2); \
auto equal_cols = \
::gko::detail::get_batch_size(_op1).get_common_size()[1] == \
::gko::detail::get_batch_size(_op2).get_common_size()[1]; \
if (!equal_num_items) { \
throw ::gko::ValueMismatch( \
__FILE__, __LINE__, __func__, \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
"expected equal number of batch items"); \
} else if (!equal_cols) { \
if (!equal_cols) { \
throw ::gko::DimensionMismatch( \
__FILE__, __LINE__, __func__, #_op1, \
::gko::detail::get_batch_size(_op1).get_common_size()[0], \
Expand All @@ -450,30 +438,22 @@ inline size_type get_num_batch_items(const T& obj)
* @throw DimensionMismatch if `_op1` and `_op2` differ in the number of
* rows or columns
*/
#define GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(_op1, _op2) \
{ \
auto equal_num_items = \
::gko::detail::get_batch_size(_op1).get_num_batch_items() == \
::gko::detail::get_batch_size(_op2).get_num_batch_items(); \
auto equal_size = \
::gko::detail::get_batch_size(_op1).get_common_size() == \
::gko::detail::get_batch_size(_op2).get_common_size(); \
if (!equal_num_items) { \
throw ::gko::ValueMismatch( \
__FILE__, __LINE__, __func__, \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
::gko::detail::get_batch_size(_op2).get_num_batch_items(), \
"expected equal number of batch items"); \
} else if (!equal_size) { \
throw ::gko::DimensionMismatch( \
__FILE__, __LINE__, __func__, #_op1, \
::gko::detail::get_batch_size(_op1).get_common_size()[0], \
::gko::detail::get_batch_size(_op1).get_common_size()[1], \
#_op2, \
::gko::detail::get_batch_size(_op2).get_common_size()[0], \
::gko::detail::get_batch_size(_op2).get_common_size()[1], \
"expected matching size among all batch items"); \
} \
#define GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(_op1, _op2) \
{ \
GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(_op1, _op2); \
auto equal_size = \
::gko::detail::get_batch_size(_op1).get_common_size() == \
::gko::detail::get_batch_size(_op2).get_common_size(); \
if (!equal_size) { \
throw ::gko::DimensionMismatch( \
__FILE__, __LINE__, __func__, #_op1, \
::gko::detail::get_batch_size(_op1).get_common_size()[0], \
::gko::detail::get_batch_size(_op1).get_common_size()[1], \
#_op2, \
::gko::detail::get_batch_size(_op2).get_common_size()[0], \
::gko::detail::get_batch_size(_op2).get_common_size()[1], \
"expected matching size among all batch items"); \
} \
}


Expand Down

0 comments on commit 811b812

Please sign in to comment.