Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the automatical strategy when shared. #559

Merged
merged 5 commits into from
Jun 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 1 addition & 43 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,49 +128,7 @@ void Csr<ValueType, IndexType>::convert_to(
result->col_idxs_ = this->col_idxs_;
result->row_ptrs_ = this->row_ptrs_;
result->set_size(this->get_size());
auto strat = this->get_strategy().get();
using Other = Csr<next_precision<ValueType>, IndexType>;
std::shared_ptr<typename Other::strategy_type> new_strat;
// TODO clean this up as soon as we improve strategy_type
if (dynamic_cast<classical *>(strat)) {
new_strat = std::make_shared<typename Other::classical>();
} else if (dynamic_cast<merge_path *>(strat)) {
new_strat = std::make_shared<typename Other::merge_path>();
} else if (dynamic_cast<cusparse *>(strat)) {
new_strat = std::make_shared<typename Other::cusparse>();
} else if (dynamic_cast<sparselib *>(strat)) {
new_strat = std::make_shared<typename Other::sparselib>();
} else {
auto rexec = result->get_executor();
auto cuda_exec = std::dynamic_pointer_cast<const CudaExecutor>(rexec);
auto hip_exec = std::dynamic_pointer_cast<const HipExecutor>(rexec);
auto lb = dynamic_cast<load_balance *>(strat);
if (cuda_exec) {
if (lb) {
new_strat =
std::make_shared<typename Other::load_balance>(cuda_exec);
} else {
new_strat =
std::make_shared<typename Other::automatical>(cuda_exec);
}
} else if (hip_exec) {
if (lb) {
new_strat =
std::make_shared<typename Other::load_balance>(hip_exec);
} else {
new_strat =
std::make_shared<typename Other::automatical>(hip_exec);
}
} else {
// FIXME this creates a long delay
if (lb) {
new_strat = std::make_shared<typename Other::load_balance>();
} else {
new_strat = std::make_shared<typename Other::automatical>();
}
}
}
result->set_strategy(new_strat);
convert_strategy_helper(result);
}


Expand Down
10 changes: 8 additions & 2 deletions core/test/matrix/csr_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,17 @@ TYPED_TEST(CsrBuilder, UpdatesSrowOnDestruction)
using index_type = typename TestFixture::index_type;
struct mock_strategy : public Mtx::strategy_type {
virtual void process(const gko::Array<index_type> &,
gko::Array<index_type> *)
gko::Array<index_type> *) override
{
*was_called = true;
}
virtual int64_t clac_size(const int64_t nnz) { return 0; }

virtual int64_t clac_size(const int64_t nnz) override { return 0; }
tcojean marked this conversation as resolved.
Show resolved Hide resolved

virtual std::shared_ptr<typename Mtx::strategy_type> copy() override
{
return std::make_shared<mock_strategy>(*was_called);
}

mock_strategy(bool &flag) : Mtx::strategy_type(""), was_called(&flag) {}

Expand Down
25 changes: 25 additions & 0 deletions cuda/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,4 +676,29 @@ TEST_F(Csr, SortUnsortedMatrixIsEquivalentToRef)
}


TEST_F(Csr, OneAutomaticalWorksWithDifferentMatrices)
{
auto automatical = std::make_shared<Mtx::automatical>();
auto row_len_limit = std::max(automatical->nvidia_row_len_limit,
automatical->amd_row_len_limit);
auto load_balance_mtx = Mtx::create(ref);
auto classical_mtx = Mtx::create(ref);
load_balance_mtx->copy_from(
gen_mtx<Vec>(1, row_len_limit + 1000, row_len_limit + 1));
classical_mtx->copy_from(gen_mtx<Vec>(50, 50, 1));
auto load_balance_mtx_d = Mtx::create(cuda);
auto classical_mtx_d = Mtx::create(cuda);
load_balance_mtx_d->copy_from(load_balance_mtx.get());
classical_mtx_d->copy_from(classical_mtx.get());

load_balance_mtx_d->set_strategy(automatical);
classical_mtx_d->set_strategy(automatical);

EXPECT_EQ("load_balance", load_balance_mtx_d->get_strategy()->get_name());
EXPECT_EQ("classical", classical_mtx_d->get_strategy()->get_name());
ASSERT_NE(load_balance_mtx_d->get_strategy().get(),
classical_mtx_d->get_strategy().get());
}


} // namespace
2 changes: 1 addition & 1 deletion doc/helpers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function(ginkgo_doc_gen name in pdf mainpage-in)
${doxygen_base_input}
)
# pick some markdown files we want as pages
set(doxygen_markdown_files "../../INSTALL.md ../../TESTING.md ../../BENCHMARKING.md")
set(doxygen_markdown_files "../../INSTALL.md ../../TESTING.md ../../BENCHMARKING.md ../../CONTRIBUTING.md")
ginkgo_to_string(doxygen_base_input_str ${doxygen_base_input} )
ginkgo_to_string(doxygen_dev_input_str ${doxygen_dev_input} )
ginkgo_to_string(doxygen_image_path_str ${doxygen_image_path} )
Expand Down
25 changes: 25 additions & 0 deletions hip/test/matrix/csr_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,4 +650,29 @@ TEST_F(Csr, SortUnsortedMatrixIsEquivalentToRef)
}


TEST_F(Csr, OneAutomaticalWorksWithDifferentMatrices)
{
auto automatical = std::make_shared<Mtx::automatical>(hip);
auto row_len_limit = std::max(automatical->nvidia_row_len_limit,
automatical->amd_row_len_limit);
auto load_balance_mtx = Mtx::create(ref);
auto classical_mtx = Mtx::create(ref);
load_balance_mtx->copy_from(
gen_mtx<Vec>(1, row_len_limit + 1000, row_len_limit + 1));
classical_mtx->copy_from(gen_mtx<Vec>(50, 50, 1));
auto load_balance_mtx_d = Mtx::create(hip);
auto classical_mtx_d = Mtx::create(hip);
load_balance_mtx_d->copy_from(load_balance_mtx.get());
classical_mtx_d->copy_from(classical_mtx.get());

load_balance_mtx_d->set_strategy(automatical);
classical_mtx_d->set_strategy(automatical);

EXPECT_EQ("load_balance", load_balance_mtx_d->get_strategy()->get_name());
EXPECT_EQ("classical", classical_mtx_d->get_strategy()->get_name());
ASSERT_NE(load_balance_mtx_d->get_strategy().get(),
classical_mtx_d->get_strategy().get());
}


} // namespace
Loading