Skip to content

Commit

Permalink
batch_cg passed all tests when using slm
Browse files Browse the repository at this point in the history
  • Loading branch information
phu0ngng committed Mar 9, 2023
1 parent 1207ff0 commit 997ef8c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 42 deletions.
31 changes: 12 additions & 19 deletions dpcpp/solver/batch_cg_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,52 +96,43 @@ class KernelCaller {
const dim3 block(group_size);
const dim3 grid(num_batches);

const size_type slm_size =
size_type slm_size =
device.get_info<sycl::info::device::local_mem_size>();
const auto matrix_storage = a.get_entry_storage();
const auto matrix_size = a.get_entry_storage();
size_type shmem_per_blk =
slm_size - 3 * sizeof(ValueType) - 2 * sizeof(real_type) -
matrix_storage; // reserve 5 for intermediate rho-s and norms
slm_size - matrix_size - 3 * sizeof(ValueType) -
2 * sizeof(real_type); // reserve 5 for intermediate rho-s, norms,
// and alp
if (shmem_per_blk < 0) shmem_per_blk = 0;
const int shared_gap = ((nrows - 1) / 8 + 1) *
8; // TODO: check if it is neccessary to align
const int shared_gap =
nrows; // TODO: check if it is neccessary to align
const size_type prec_size =
PrecType::dynamic_work_size(shared_gap, a.num_nnz) *
sizeof(ValueType);
const auto sconf =
gko::kernels::batch_cg::compute_shared_storage<PrecType, ValueType>(
shmem_per_blk, shared_gap, a.num_nnz,
b.num_rhs); // TODO: Make it works with shared_pc
shmem_per_blk, shared_gap, a.num_nnz, b.num_rhs);
const size_t shared_size =
sconf.n_shared * shared_gap * sizeof(ValueType) +
(sconf.prec_shared ? prec_size : 0);
auto workspace = gko::Array<ValueType>(
exec_, sconf.gmem_stride_bytes * num_batches / sizeof(ValueType));
assert(sconf.gmem_stride_bytes % sizeof(ValueType) == 0);

/*
const int workspace_size =
gko::kernels::batch_cg::local_memory_requirement<ValueType>(nrows,
nrhs) +
PrecType::dynamic_work_size(nrows, a.num_nnz) *
sizeof(ValueType); auto workspace = gko::Array<ValueType>( exec_,
workspace_size * num_batch_entries / sizeof(ValueType));
*/
ValueType* const workspace_data = workspace.get_data();
auto b_values = b.values;
auto x_values = x.values;
auto max_iters = opts_.max_its;
auto res_tol = opts_.residual_tol;
const int local_accessor_size = shared_size + 3 * sizeof(ValueType);
//(slm_size - 2 * sizeof(real_type)) / sizeof(ValueType);

(exec_->get_queue())->submit([&](sycl::handler& cgh) {
sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
sycl::access::target::local>
slm_values(sycl::range<1>(local_accessor_size), cgh);
sycl::accessor<real_type, 1, sycl::access_mode::read_write,
sycl::access::target::local>
slm_no_cmplx(sycl::range<1>(2), cgh);
slm_reals(sycl::range<1>(2), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block),
Expand All @@ -161,7 +152,9 @@ class KernelCaller {
apply_kernel<StopType>(
sconf, max_iters, res_tol, logger, prec,
a_global_entry, b_global_entry, x_global_entry,
nrows, a.num_nnz, slm_values, slm_no_cmplx,
nrows, a.num_nnz,
static_cast<ValueType*>(slm_values.get_pointer()),
static_cast<real_type*>(slm_reals.get_pointer()),
item_ct1, workspace_data);
});
});
Expand Down
40 changes: 17 additions & 23 deletions dpcpp/solver/batch_cg_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,17 @@ __dpct_inline__ void update_x_and_r(const int num_rows,

template <typename StopType, typename PrecType, typename LogType,
typename BatchMatrixType, typename ValueType>
void apply_kernel(
const gko::kernels::batch_cg::StorageConfig sconf, const int max_iter,
const gko::remove_complex<ValueType> tol, LogType logger,
PrecType prec_shared, const BatchMatrixType A_global_entry,
const ValueType* const __restrict__ b_global_entry,
ValueType* const __restrict__ x_global_entry, const size_type nrows,
const size_type nnz,
sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
sycl::access::target::local>
local_mem_sh,
sycl::accessor<gko::remove_complex<ValueType>, 1,
sycl::access_mode::read_write, sycl::access::target::local>
all_norms_sh,
sycl::nd_item<3> item_ct1,
ValueType* const __restrict__ workspace = nullptr)
void apply_kernel(const gko::kernels::batch_cg::StorageConfig sconf,
const int max_iter, const gko::remove_complex<ValueType> tol,
LogType logger, PrecType prec_shared,
const BatchMatrixType A_global_entry,
const ValueType* const __restrict__ b_global_entry,
ValueType* const __restrict__ x_global_entry,
const size_type nrows, const size_type nnz,
ValueType* slm_values,
gko::remove_complex<ValueType>* slm_reals,
sycl::nd_item<3> item_ct1,
ValueType* const __restrict__ workspace = nullptr)
{
using real_type = typename gko::remove_complex<ValueType>;

Expand All @@ -159,24 +155,22 @@ void apply_kernel(

const auto ibatch = item_ct1.get_group_linear_id();

ValueType* rho_old_sh = &local_mem_sh[0];
ValueType* rho_new_sh = &local_mem_sh[1];
ValueType* alpha_sh = &local_mem_sh[2];
remove_complex<ValueType>* norms_rhs_sh = &all_norms_sh[0];
remove_complex<ValueType>* norms_res_sh = &all_norms_sh[1];
ValueType* rho_old_sh = &slm_values[0];
ValueType* rho_new_sh = &slm_values[1];
ValueType* alpha_sh = &slm_values[2];
remove_complex<ValueType>* norms_rhs_sh = &slm_reals[0];
remove_complex<ValueType>* norms_res_sh = &slm_reals[1];

const int gmem_offset =
ibatch * sconf.gmem_stride_bytes / sizeof(ValueType);
// extern __shared__ char local_mem_sh[];
ValueType* r_sh;
ValueType* z_sh;
ValueType* p_sh;
ValueType* Ap_sh;
ValueType* x_sh;
ValueType* prec_work_sh;
if (sconf.n_shared >= 1) {
// r_sh = reinterpret_cast<ValueType*>(local_mem_sh);
r_sh = &local_mem_sh[3];
r_sh = &slm_values[3];
} else {
r_sh = workspace + gmem_offset;
}
Expand Down

0 comments on commit 997ef8c

Please sign in to comment.