templated batch_cg for all optimization parameters
phu0ngng committed Jun 6, 2023
1 parent 68f9565 commit 992173a
Showing 2 changed files with 94 additions and 84 deletions.
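Summary: the per-configuration kernel launch that was previously written as a generic lambda inside call_kernel is pulled out into a templated member function launch_apply_kernel, parameterized on the stopping criterion (StopType), the sub-group/SIMD width, whether all CG vectors fit in shared local memory, and whether sub-group-only kernels are sufficient. call_kernel then selects one instantiation per problem size. Below is a minimal, stand-alone sketch of that dispatch pattern with simplified, hypothetical names and parameters; the actual signatures are in the diff that follows.

#include <cstdio>

// Hypothetical stand-in for the templated launcher: each instantiation
// compiles to a separate kernel specialized for these compile-time choices.
template <int simd_len, bool vecs_shared_all, bool sg_kernels_only>
void launch(int nrows)
{
    std::printf("nrows=%d simd=%d shared=%d sg_only=%d\n", nrows, simd_len,
                vecs_shared_all, sg_kernels_only);
}

// Runtime dispatch mirroring the size-based branches in call_kernel.
void dispatch(int nrows, bool everything_fits_in_slm)
{
    if (nrows <= 32) {
        launch<16, true, true>(nrows);
    } else if (nrows <= 256 && everything_fits_in_slm) {
        launch<32, true, true>(nrows);
    } else if (everything_fits_in_slm) {
        launch<32, true, false>(nrows);
    } else {
        launch<32, false, false>(nrows);
    }
}

int main()
{
    dispatch(24, true);    // small system: sub-group-only kernel
    dispatch(500, false);  // large system: vectors spill to global memory
}

Each branch instantiates a distinct kernel, so the compile-time flags can strip unused code paths (for example, work-group-wide reductions) instead of branching at run time inside the kernel.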
116 changes: 64 additions & 52 deletions dpcpp/solver/batch_cg_kernels.dp.cpp
@@ -76,33 +76,75 @@ class KernelCaller {
: exec_{exec}, opts_{opts}
{}

template <typename StopType, const int simd_len, const bool vecs_shared_all,
const bool sg_kernels_only, typename PrecType, typename LogType,
typename BatchMatrixType>
__dpct_inline__ void launch_apply_kernel(
const gko::kernels::batch_cg::StorageConfig& sconf, LogType& logger,
PrecType& prec, const BatchMatrixType a,
const ValueType* const __restrict__ b_values,
ValueType* const __restrict__ x_values,
ValueType* const __restrict__ workspace,
const size_t& shared_size) const
{
auto nrows = a.num_rows;
int group_size =
(exec_->get_queue()->get_device())
.get_info<sycl::info::device::max_work_group_size>();
if (group_size > 2 * nrows) group_size = get_larger_power(nrows);

const dim3 block(group_size);
const dim3 grid(a.num_batch);

auto max_iters = opts_.max_its;
auto res_tol = opts_.residual_tol;

(exec_->get_queue())->submit([&](sycl::handler& cgh) {
sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
sycl::access::target::local>
slm_values(sycl::range<1>(shared_size), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
simd_len)]] [[intel::kernel_args_restrict]] {
auto batch_id = item_ct1.get_group_linear_id();
const auto a_global_entry =
gko::batch::batch_entry(a, batch_id);
const ValueType* const b_global_entry =
gko::batch::batch_entry_ptr(b_values, 1, nrows,
batch_id);
ValueType* const x_global_entry =
gko::batch::batch_entry_ptr(x_values, 1, nrows,
batch_id);
apply_kernel<StopType, vecs_shared_all, sg_kernels_only>(
sconf, max_iters, res_tol, logger, prec, a_global_entry,
b_global_entry, x_global_entry, nrows, a.num_nnz,
static_cast<ValueType*>(slm_values.get_pointer()),
item_ct1, workspace);
});
});
}

template <typename BatchMatrixType, typename PrecType, typename StopType,
typename LogType>
void call_kernel(LogType logger, const BatchMatrixType& a, PrecType prec,
const gko::batch_dense::UniformBatch<const ValueType>& b,
const gko::batch_dense::UniformBatch<ValueType>& x) const
{
using real_type = typename gko::remove_complex<ValueType>;
const size_type num_batches = a.num_batch;
const auto nrows = a.num_rows;
const auto nrhs = b.num_rhs;
GKO_ASSERT(nrhs == 1);

auto device = exec_->get_queue()->get_device();
int group_size =
device.get_info<sycl::info::device::max_work_group_size>();
if (group_size > 2 * nrows) group_size = get_larger_power(nrows);
constexpr int subgroup_size = config::warp_size;

const dim3 block(group_size);
const dim3 grid(num_batches);

size_type shmem_per_blk =
device.get_info<sycl::info::device::local_mem_size>() -
(group_size + 3) * sizeof(ValueType) - 2 * sizeof(real_type);
const auto matrix_size = a.get_entry_storage();
const int shared_gap =
nrows; // TODO: check if it is neccessary to align
const int shared_gap = nrows;
const size_type prec_size =
PrecType::dynamic_work_size(shared_gap, a.num_nnz) *
sizeof(ValueType);
@@ -113,56 +155,26 @@ class KernelCaller {
sconf.n_shared * shared_gap +
(sconf.prec_shared ? prec_size : 0) / sizeof(ValueType);
auto workspace = gko::Array<ValueType>(
exec_, sconf.gmem_stride_bytes * num_batches / sizeof(ValueType));
assert(sconf.gmem_stride_bytes % sizeof(ValueType) == 0);
// std::cout << "HERE: " << sconf.n_shared << " " <<
// sconf.prec_shared << std::endl;
exec_, sconf.gmem_stride_bytes * a.num_batch / sizeof(ValueType));

ValueType* const workspace_data = workspace.get_data();
auto b_values = b.values;
auto x_values = x.values;
auto max_iters = opts_.max_its;
auto res_tol = opts_.residual_tol;

auto launch_apply_kernel = [&]<const int SIMDLEN,
const bool vecs_shared_all,
const bool sg_kernel_only>() {
(exec_->get_queue())->submit([&](sycl::handler& cgh) {
sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
sycl::access::target::local>
slm_values(sycl::range<1>(shared_size), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
SIMDLEN)]] {
auto batch_id = item_ct1.get_group_linear_id();
const auto a_global_entry =
gko::batch::batch_entry(a, batch_id);
const ValueType* const b_global_entry =
gko::batch::batch_entry_ptr(b_values, 1, nrows,
batch_id);
ValueType* const x_global_entry =
gko::batch::batch_entry_ptr(x_values, 1, nrows,
batch_id);
apply_kernel<StopType, vecs_shared_all, sg_kernel_only>(
sconf, max_iters, res_tol, logger, prec,
a_global_entry, b_global_entry, x_global_entry,
nrows, a.num_nnz,
static_cast<ValueType*>(slm_values.get_pointer()),
item_ct1, workspace_data);
});
});
};

if (nrows <= 32)
launch_apply_kernel<16, 1, 1>();
launch_apply_kernel<StopType, 16, 1, 1>(
sconf, logger, prec, a, b.values, x.values, workspace_data,
shared_size);
else if (nrows <= 256 && sconf.n_global == 0)
launch_apply_kernel<32, 1, 1>();
launch_apply_kernel<StopType, 32, 1, 1>(
sconf, logger, prec, a, b.values, x.values, workspace_data,
shared_size);
else if (sconf.n_global == 0)
launch_apply_kernel<32, 1, 0>();
launch_apply_kernel<StopType, 32, 1, 0>(
sconf, logger, prec, a, b.values, x.values, workspace_data,
shared_size);
else
launch_apply_kernel<32, 0, 0>();
launch_apply_kernel<StopType, 32, 0, 0>(
sconf, logger, prec, a, b.values, x.values, workspace_data,
shared_size);
}

private:
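For orientation, the work-group size and shared-local-memory budget computed in call_kernel feed a StorageConfig that decides how many of the per-system CG vectors (and optionally the preconditioner work array) stay in shared local memory; whatever does not fit is spilled to the global workspace sized sconf.gmem_stride_bytes * num_batches. The following is only a rough sketch of the kind of decision that helper encodes, with assumed names and an assumed vector count; the real StorageConfig logic lives elsewhere in the batch solver code and is not part of this diff.

#include <cstddef>

// Illustrative only: greedily keep as many per-system vectors in shared
// local memory as the budget allows, spill the rest to a global workspace.
struct storage_config {
    int n_shared;      // vectors kept in shared local memory
    int n_global;      // vectors spilled to the global workspace
    bool prec_shared;  // preconditioner work array kept in SLM?
};

inline storage_config pick_storage(std::size_t slm_budget_bytes, int nrows,
                                   std::size_t prec_bytes,
                                   std::size_t value_size, int num_vectors)
{
    storage_config c{0, num_vectors, false};
    const std::size_t vec_bytes =
        static_cast<std::size_t>(nrows) * value_size;
    while (c.n_shared < num_vectors &&
           static_cast<std::size_t>(c.n_shared + 1) * vec_bytes <=
               slm_budget_bytes) {
        ++c.n_shared;
        --c.n_global;
    }
    // Keep the preconditioner data in SLM only if every vector already
    // fits and there is still room left over.
    c.prec_shared = (c.n_global == 0) &&
                    (c.n_shared * vec_bytes + prec_bytes <= slm_budget_bytes);
    return c;
}

Read this way, the dispatch above follows naturally: sconf.n_global == 0 means every vector fits in shared local memory, which is what the vecs_shared_all template flag encodes.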
62 changes: 30 additions & 32 deletions dpcpp/solver/batch_cg_kernels.hpp.inc
@@ -32,16 +32,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

namespace {

template <const bool sg_kernels_only, typename BatchMatrixType_entry,
typename PrecType, typename ValueType>
template <const bool sg_kernels_only, typename PrecType, typename ValueType,
typename BatchMatrixType>
__dpct_inline__ void initialize(
const int num_rows, const BatchMatrixType_entry& A_global_entry,
const ValueType* const b_global_entry,
const ValueType* const x_global_entry, ValueType* const x_shared_entry,
ValueType* const r_shared_entry, const PrecType& prec_shared,
ValueType* const z_shared_entry, ValueType& rho_old,
ValueType* const p_shared_entry, gko::remove_complex<ValueType>& rhs_norms,
sycl::nd_item<3> item_ct1)
const int num_rows, const BatchMatrixType& A_global_entry,
const ValueType* const __restrict__ b_global_entry,
const ValueType* const __restrict__ x_global_entry,
ValueType* const __restrict__ x_shared_entry,
ValueType* const __restrict__ r_shared_entry, const PrecType& prec_shared,
ValueType* const __restrict__ z_shared_entry, ValueType& rho_old,
ValueType* const __restrict__ p_shared_entry,
gko::remove_complex<ValueType>& rhs_norms, sycl::nd_item<3> item_ct1)
{
auto sg = item_ct1.get_sub_group();
auto group = item_ct1.get_group();
@@ -91,12 +92,11 @@ __dpct_inline__ void initialize(


template <typename ValueType>
__dpct_inline__ void update_p(const int num_rows,
const ValueType& rho_new_shared_entry,
const ValueType& rho_old_shared_entry,
const ValueType* const z_shared_entry,
ValueType* const p_shared_entry,
sycl::nd_item<3> item_ct1)
__dpct_inline__ void update_p(
const int num_rows, const ValueType& rho_new_shared_entry,
const ValueType& rho_old_shared_entry,
const ValueType* const __restrict__ z_shared_entry,
ValueType* const __restrict__ p_shared_entry, sycl::nd_item<3> item_ct1)
{
const ValueType beta = rho_new_shared_entry / rho_old_shared_entry;
for (int li = item_ct1.get_local_linear_id(); li < num_rows;
@@ -106,14 +106,12 @@ __dpct_inline__ void update_p(const int num_rows,
}

template <const int sg_kernels_only, typename ValueType>
__dpct_inline__ void update_x_and_r(const int num_rows,
const ValueType rho_old_shared_entry,
const ValueType* const p_shared_entry,
const ValueType* const Ap_shared_entry,
ValueType& alpha_shared_entry,
ValueType* const x_shared_entry,
ValueType* const r_shared_entry,
sycl::nd_item<3> item_ct1)
__dpct_inline__ void update_x_and_r(
const int num_rows, const ValueType rho_old_shared_entry,
const ValueType* const __restrict__ p_shared_entry,
const ValueType* const __restrict__ Ap_shared_entry,
ValueType& alpha_shared_entry, ValueType* const __restrict__ x_shared_entry,
ValueType* const __restrict__ r_shared_entry, sycl::nd_item<3> item_ct1)
{
auto group = item_ct1.get_group();
auto sg = item_ct1.get_sub_group();
@@ -143,15 +141,15 @@ __dpct_inline__ void update_x_and_r(const int num_rows,
template <typename StopType, const bool vecs_shared_all,
const bool sg_kernels_only, typename PrecType, typename LogType,
typename BatchMatrixType, typename ValueType>
void apply_kernel(const gko::kernels::batch_cg::StorageConfig sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType A_global_entry,
const ValueType* const __restrict__ b_global_entry,
ValueType* const __restrict__ x_global_entry,
const size_type nrows, const size_type nnz,
ValueType* slm_values, sycl::nd_item<3> item_ct1,
ValueType* const __restrict__ workspace = nullptr)
__dpct_inline__ void apply_kernel(
const gko::kernels::batch_cg::StorageConfig sconf, const int max_iter,
const gko::remove_complex<ValueType> tol, LogType logger,
PrecType prec_shared, const BatchMatrixType& A_global_entry,
const ValueType* const __restrict__ b_global_entry,
ValueType* const __restrict__ x_global_entry, const size_type nrows,
const size_type nnz, ValueType* const __restrict__ slm_values,
sycl::nd_item<3> item_ct1,
ValueType* const __restrict__ workspace = nullptr)
{
using real_type = typename gko::remove_complex<ValueType>;

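In this second file, the device-side helpers (initialize, update_p, update_x_and_r, apply_kernel) gain __restrict__-qualified pointer parameters, apply_kernel is marked __dpct_inline__, and the matrix entry is passed by const reference. When the sg_kernels_only flag is set, these helpers can rely on sub-group collectives alone. The sketch below illustrates that sub-group reduction pattern; it assumes a real-valued ValueType, uses standard SYCL 2020 collectives, and is not the actual Ginkgo helper.

#include <sycl/sycl.hpp>

// Sketch of a sub-group-only dot product: when one batch system is small
// enough for a single sub-group, the reduction can use sub-group
// collectives alone and skip work-group barriers.
template <typename ValueType>
ValueType subgroup_dot(const ValueType* const __restrict__ x,
                       const ValueType* const __restrict__ y, const int n,
                       sycl::nd_item<3> item)
{
    auto sg = item.get_sub_group();
    ValueType partial{};
    for (int i = sg.get_local_linear_id(); i < n;
         i += sg.get_local_linear_range()) {
        partial += x[i] * y[i];
    }
    return sycl::reduce_over_group(sg, partial, sycl::plus<ValueType>());
}

The added __restrict__ qualifiers promise non-aliasing pointers to the compiler, and __dpct_inline__ (commonly defined as a forced-inline attribute in the DPC++ compatibility headers) keeps these helpers from adding call overhead inside the solver loop.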
