Skip to content

Commit

Permalink
extract allocations from CSR OpenMP
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Apr 10, 2021
1 parent a62ed0d commit 25b94ce
Showing 1 changed file with 216 additions and 89 deletions.
305 changes: 216 additions & 89 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <algorithm>
#include <limits>
#include <numeric>
#include <utility>

Expand All @@ -51,6 +52,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "core/base/allocator.hpp"
#include "core/base/iterator_factory.hpp"
#include "core/base/utils.hpp"
#include "core/components/prefix_sum.hpp"
#include "core/matrix/csr_builder.hpp"
#include "omp/components/csr_spgeam.hpp"
Expand Down Expand Up @@ -130,100 +132,163 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL);


template <typename ValueType, typename IndexType>
void spgemm_insert_row(unordered_set<IndexType> &cols,
const matrix::Csr<ValueType, IndexType> *c,
size_type row)
{
auto row_ptrs = c->get_const_row_ptrs();
auto col_idxs = c->get_const_col_idxs();
cols.insert(col_idxs + row_ptrs[row], col_idxs + row_ptrs[row + 1]);
}
namespace {


template <typename ValueType, typename IndexType>
void spgemm_insert_row2(unordered_set<IndexType> &cols,
const matrix::Csr<ValueType, IndexType> *a,
const matrix::Csr<ValueType, IndexType> *b,
size_type row)
{
auto a_row_ptrs = a->get_const_row_ptrs();
auto a_col_idxs = a->get_const_col_idxs();
auto b_row_ptrs = b->get_const_row_ptrs();
auto b_col_idxs = b->get_const_col_idxs();
for (size_type a_nz = a_row_ptrs[row];
a_nz < size_type(a_row_ptrs[row + 1]); ++a_nz) {
auto a_col = a_col_idxs[a_nz];
auto b_row = a_col;
cols.insert(b_col_idxs + b_row_ptrs[b_row],
b_col_idxs + b_row_ptrs[b_row + 1]);
}
}
struct col_heap_element {
using value_type = ValueType;
using index_type = IndexType;
using matrix_type = matrix::Csr<ValueType, IndexType>;

IndexType idx;
IndexType end;
IndexType col;

ValueType val() const { return zero<ValueType>(); }

col_heap_element(IndexType idx, IndexType end, IndexType col, ValueType)
: idx{idx}, end{end}, col{col}
{}
};


template <typename ValueType, typename IndexType>
void spgemm_accumulate_row(map<IndexType, ValueType> &cols,
const matrix::Csr<ValueType, IndexType> *c,
ValueType scale, size_type row)
struct val_heap_element {
using value_type = ValueType;
using index_type = IndexType;
using matrix_type = matrix::Csr<ValueType, IndexType>;

IndexType idx;
IndexType end;
IndexType col;
ValueType val_;

ValueType val() const { return val_; }

val_heap_element(IndexType idx, IndexType end, IndexType col, ValueType val)
: idx{idx}, end{end}, col{col}, val_{val}
{}
};


template <typename HeapElement>
void sift_down(HeapElement *heap, typename HeapElement::index_type idx,
typename HeapElement::index_type size)
{
auto row_ptrs = c->get_const_row_ptrs();
auto col_idxs = c->get_const_col_idxs();
auto vals = c->get_const_values();
for (size_type c_nz = row_ptrs[row]; c_nz < size_type(row_ptrs[row + 1]);
++c_nz) {
auto c_col = col_idxs[c_nz];
auto c_val = vals[c_nz];
cols[c_col] += scale * c_val;
auto curcol = heap[idx].col;
while (idx * 2 + 1 < size) {
auto lchild = idx * 2 + 1;
auto rchild = min(lchild + 1, size - 1);
auto lcol = heap[lchild].col;
auto rcol = heap[rchild].col;
auto mincol = min(lcol, rcol);
if (mincol >= curcol) {
break;
}
auto minchild = lcol == mincol ? lchild : rchild;
std::swap(heap[minchild], heap[idx]);
idx = minchild;
}
}


template <typename ValueType, typename IndexType>
void spgemm_accumulate_row2(map<IndexType, ValueType> &cols,
const matrix::Csr<ValueType, IndexType> *a,
const matrix::Csr<ValueType, IndexType> *b,
ValueType scale, size_type row)
template <typename HeapElement, typename InitCallback, typename StepCallback,
typename ColCallback>
auto spgemm_multiway_merge(size_type row,
const typename HeapElement::matrix_type *a,
const typename HeapElement::matrix_type *b,
HeapElement *heap, InitCallback init_cb,
StepCallback step_cb, ColCallback col_cb)
-> decltype(init_cb(0))
{
auto a_row_ptrs = a->get_const_row_ptrs();
auto a_col_idxs = a->get_const_col_idxs();
auto a_cols = a->get_const_col_idxs();
auto a_vals = a->get_const_values();
auto b_row_ptrs = b->get_const_row_ptrs();
auto b_col_idxs = b->get_const_col_idxs();
auto b_cols = b->get_const_col_idxs();
auto b_vals = b->get_const_values();
for (size_type a_nz = a_row_ptrs[row];
a_nz < size_type(a_row_ptrs[row + 1]); ++a_nz) {
auto a_col = a_col_idxs[a_nz];
auto a_val = a_vals[a_nz];
auto b_row = a_col;
for (size_type b_nz = b_row_ptrs[b_row];
b_nz < size_type(b_row_ptrs[b_row + 1]); ++b_nz) {
auto b_col = b_col_idxs[b_nz];
auto b_val = b_vals[b_nz];
cols[b_col] += scale * a_val * b_val;
auto a_begin = a_row_ptrs[row];
auto a_end = a_row_ptrs[row + 1];

using index_type = typename HeapElement::index_type;
constexpr auto sentinel = std::numeric_limits<index_type>::max();

auto state = init_cb(row);

// initialize the heap
for (auto a_nz = a_begin; a_nz < a_end; ++a_nz) {
auto b_row = a_cols[a_nz];
auto b_begin = b_row_ptrs[b_row];
auto b_end = b_row_ptrs[b_row + 1];
heap[a_nz] = {b_begin, b_end,
checked_load(b_cols, b_begin, b_end, sentinel),
a_vals[a_nz]};
}

if (a_begin != a_end) {
// make heap:
auto a_size = a_end - a_begin;
for (auto i = (a_size - 2) / 2; i >= 0; --i) {
sift_down(heap + a_begin, i, a_size);
}
auto &top = heap[a_begin];
auto &bot = heap[a_end - 1];
auto col = top.col;

while (top.col != sentinel) {
step_cb(b_vals[top.idx] * top.val(), top.col, state);
// move to the next element
top.idx++;
top.col = checked_load(b_cols, top.idx, top.end, sentinel);
// restore heap property
// pop_heap swaps top and bot, we need to prevent that
// so that we do a simple sift_down instead
sift_down(heap + a_begin, index_type{}, a_size);
if (top.col != col) {
col_cb(col, state);
}
col = top.col;
}
}

return state;
}


} // namespace


template <typename ValueType, typename IndexType>
void spgemm(std::shared_ptr<const OmpExecutor> exec,
const matrix::Csr<ValueType, IndexType> *a,
const matrix::Csr<ValueType, IndexType> *b,
matrix::Csr<ValueType, IndexType> *c)
{
auto num_rows = a->get_size()[0];

// first sweep: count nnz for each row
auto c_row_ptrs = c->get_row_ptrs();

unordered_set<IndexType> local_col_idxs(exec);
#pragma omp parallel for firstprivate(local_col_idxs)
Array<col_heap_element<ValueType, IndexType>> col_heap_array(
exec, a->get_num_stored_elements());

auto col_heap = col_heap_array.get_data();

// first sweep: count nnz for each row
#pragma omp parallel for
for (size_type a_row = 0; a_row < num_rows; ++a_row) {
local_col_idxs.clear();
spgemm_insert_row2(local_col_idxs, a, b, a_row);
c_row_ptrs[a_row] = local_col_idxs.size();
c_row_ptrs[a_row] = spgemm_multiway_merge(
a_row, a, b, col_heap, [](size_type) { return IndexType{}; },
[](ValueType, IndexType, IndexType &) {},
[](IndexType, IndexType &nnz) { nnz++; });
}

col_heap_array.clear();

Array<val_heap_element<ValueType, IndexType>> heap_array(
exec, a->get_num_stored_elements());

auto heap = heap_array.get_data();

// build row pointers
components::prefix_sum(exec, c_row_ptrs, num_rows + 1);

Expand All @@ -237,18 +302,21 @@ void spgemm(std::shared_ptr<const OmpExecutor> exec,
auto c_col_idxs = c_col_idxs_array.get_data();
auto c_vals = c_vals_array.get_data();

map<IndexType, ValueType> local_row_nzs(exec);
#pragma omp parallel for firstprivate(local_row_nzs)
#pragma omp parallel for
for (size_type a_row = 0; a_row < num_rows; ++a_row) {
local_row_nzs.clear();
spgemm_accumulate_row2(local_row_nzs, a, b, one<ValueType>(), a_row);
// store result
auto c_nz = c_row_ptrs[a_row];
for (auto pair : local_row_nzs) {
c_col_idxs[c_nz] = pair.first;
c_vals[c_nz] = pair.second;
++c_nz;
}
spgemm_multiway_merge(
a_row, a, b, heap,
[&](size_type row) {
return std::make_pair(zero<ValueType>(), c_row_ptrs[row]);
},
[](ValueType val, IndexType,
std::pair<ValueType, IndexType> &state) { state.first += val; },
[&](IndexType col, std::pair<ValueType, IndexType> &state) {
c_col_idxs[state.second] = col;
c_vals[state.second] = state.first;
state.first = zero<ValueType>();
state.second++;
});
}
}

Expand All @@ -267,17 +335,41 @@ void advanced_spgemm(std::shared_ptr<const OmpExecutor> exec,
auto num_rows = a->get_size()[0];
auto valpha = alpha->at(0, 0);
auto vbeta = beta->at(0, 0);
constexpr auto sentinel = std::numeric_limits<IndexType>::max();

// first sweep: count nnz for each row
auto c_row_ptrs = c->get_row_ptrs();
auto d_row_ptrs = d->get_const_row_ptrs();
auto d_cols = d->get_const_col_idxs();
auto d_vals = d->get_const_values();

Array<val_heap_element<ValueType, IndexType>> heap_array(
exec, a->get_num_stored_elements());

unordered_set<IndexType> local_col_idxs(exec);
#pragma omp parallel for firstprivate(local_col_idxs)
auto heap = heap_array.get_data();
auto col_heap =
reinterpret_cast<col_heap_element<ValueType, IndexType> *>(heap);

// first sweep: count nnz for each row
#pragma omp parallel for
for (size_type a_row = 0; a_row < num_rows; ++a_row) {
local_col_idxs.clear();
spgemm_insert_row(local_col_idxs, d, a_row);
spgemm_insert_row2(local_col_idxs, a, b, a_row);
c_row_ptrs[a_row] = local_col_idxs.size();
auto d_nz = d_row_ptrs[a_row];
auto d_end = d_row_ptrs[a_row + 1];
auto d_col = checked_load(d_cols, d_nz, d_end, sentinel);
c_row_ptrs[a_row] = spgemm_multiway_merge(
a_row, a, b, col_heap, [](size_type row) { return IndexType{}; },
[](ValueType, IndexType, IndexType &) {},
[&](IndexType col, IndexType &nnz) {
// skip smaller elements from d
while (d_col <= col) {
d_nz++;
nnz += d_col != col;
d_col = checked_load(d_cols, d_nz, d_end, sentinel);
}
nnz++;
});
// handle the remaining columns from d
c_row_ptrs[a_row] += d_end - d_nz;
}

// build row pointers
Expand All @@ -293,18 +385,53 @@ void advanced_spgemm(std::shared_ptr<const OmpExecutor> exec,
auto c_col_idxs = c_col_idxs_array.get_data();
auto c_vals = c_vals_array.get_data();

map<IndexType, ValueType> local_row_nzs(exec);
#pragma omp parallel for firstprivate(local_row_nzs)
#pragma omp parallel for
for (size_type a_row = 0; a_row < num_rows; ++a_row) {
local_row_nzs.clear();
spgemm_accumulate_row(local_row_nzs, d, vbeta, a_row);
spgemm_accumulate_row2(local_row_nzs, a, b, valpha, a_row);
// store result
auto c_nz = c_row_ptrs[a_row];
for (auto pair : local_row_nzs) {
c_col_idxs[c_nz] = pair.first;
c_vals[c_nz] = pair.second;
++c_nz;
auto d_nz = d_row_ptrs[a_row];
auto d_end = d_row_ptrs[a_row + 1];
auto d_col = checked_load(d_cols, d_nz, d_end, sentinel);
auto d_val = checked_load(d_vals, d_nz, d_end, zero<ValueType>());
auto c_nz =
spgemm_multiway_merge(
a_row, a, b, heap,
[&](size_type row) {
return std::make_pair(zero<ValueType>(), c_row_ptrs[row]);
},
[](ValueType val, IndexType,
std::pair<ValueType, IndexType> &state) {
state.first += val;
},
[&](IndexType col, std::pair<ValueType, IndexType> &state) {
// handle smaller elements from d
ValueType part_d_val{};
while (d_col <= col) {
if (d_col == col) {
part_d_val = d_val;
} else {
c_col_idxs[state.second] = d_col;
c_vals[state.second] = vbeta * d_val;
state.second++;
}
d_nz++;
d_col = checked_load(d_cols, d_nz, d_end, sentinel);
d_val = checked_load(d_vals, d_nz, d_end,
zero<ValueType>());
}
c_col_idxs[state.second] = col;
c_vals[state.second] =
vbeta * part_d_val + valpha * state.first;
state.first = zero<ValueType>();
state.second++;
})
.second;
// handle remaining elements from d
while (d_col < sentinel) {
c_col_idxs[c_nz] = d_col;
c_vals[c_nz] = vbeta * d_val;
c_nz++;
d_nz++;
d_col = checked_load(d_cols, d_nz, d_end, sentinel);
d_val = checked_load(d_vals, d_nz, d_end, zero<ValueType>());
}
}
}
Expand Down

0 comments on commit 25b94ce

Please sign in to comment.