Skip to content

Commit

Permalink
fixed batch_cg and batch_bicgstab, both kernels should work with larg…
Browse files Browse the repository at this point in the history
…e sizes now
  • Loading branch information
phu0ngng authored and pratikvn committed Mar 26, 2023
1 parent bb4e59f commit a5474cc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
14 changes: 4 additions & 10 deletions dpcpp/solver/batch_bicgstab_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class KernelCaller {
auto group_size =
device.get_info<sycl::info::device::max_work_group_size>();
constexpr int subgroup_size = config::warp_size;
GKO_ASSERT(nrows >= 2 * subgroup_size);
GKO_ASSERT(group_size >= 2 * subgroup_size);

const dim3 block(group_size);
const dim3 grid(num_batches);
Expand All @@ -114,15 +114,9 @@ class KernelCaller {
gko::kernels::batch_bicgstab::compute_shared_storage<PrecType,
ValueType>(
shmem_per_blk, shared_gap, a.num_nnz, b.num_rhs);
std::cout << "HERE " << sconf.n_shared << " " << sconf.n_global << " "
<< sconf.prec_shared << std::endl;
const size_t shared_size =
sconf.n_shared * shared_gap * sizeof(ValueType) +
(sconf.prec_shared ? prec_size : 0);
std::cout << "slm_size: " << slm_size << ",shared_size: " << shared_size
<< std::endl;
std::cout << "Workspace size: " << sconf.gmem_stride_bytes * num_batches
<< std::endl;
sconf.n_shared * shared_gap +
(sconf.prec_shared ? prec_size : 0) / sizeof(ValueType);
auto workspace = gko::array<ValueType>(
exec_, sconf.gmem_stride_bytes * num_batches / sizeof(ValueType));
assert(sconf.gmem_stride_bytes % sizeof(ValueType) == 0);
Expand All @@ -132,7 +126,7 @@ class KernelCaller {
auto x_values = x.values;
auto max_iters = opts_.max_its;
auto res_tol = opts_.residual_tol;
const int local_accessor_size = shared_size + 5 * sizeof(ValueType);
const int local_accessor_size = shared_size + 5;

(exec_->get_queue())->submit([&](sycl::handler& cgh) {
sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
Expand Down
22 changes: 12 additions & 10 deletions dpcpp/solver/batch_cg_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class KernelCaller {
auto group_size =
device.get_info<sycl::info::device::max_work_group_size>();
constexpr int subgroup_size = config::warp_size;
GKO_ASSERT(nrows >= 2 * subgroup_size);
GKO_ASSERT(group_size >= 2 * subgroup_size);

const dim3 block(group_size);
const dim3 grid(num_batches);
Expand All @@ -100,9 +100,9 @@ class KernelCaller {
device.get_info<sycl::info::device::local_mem_size>();
const auto matrix_size = a.get_entry_storage();
size_type shmem_per_blk =
slm_size - matrix_size - 3 * sizeof(ValueType) -
2 * sizeof(real_type); // reserve 5 for intermediate rho-s, norms,
// and alp
slm_size - 3 * sizeof(ValueType) -
2 * sizeof(real_type); // reserve 3 for intermediate rho-s, norms,
// and alpha
if (shmem_per_blk < 0) shmem_per_blk = 0;
const int shared_gap =
nrows; // TODO: check if it is neccessary to align
Expand All @@ -113,8 +113,8 @@ class KernelCaller {
gko::kernels::batch_cg::compute_shared_storage<PrecType, ValueType>(
shmem_per_blk, shared_gap, a.num_nnz, b.num_rhs);
const size_t shared_size =
sconf.n_shared * shared_gap * sizeof(ValueType) +
(sconf.prec_shared ? prec_size : 0);
sconf.n_shared * shared_gap +
(sconf.prec_shared ? prec_size : 0) / sizeof(ValueType);
auto workspace = gko::Array<ValueType>(
exec_, sconf.gmem_stride_bytes * num_batches / sizeof(ValueType));
assert(sconf.gmem_stride_bytes % sizeof(ValueType) == 0);
Expand All @@ -124,7 +124,7 @@ class KernelCaller {
auto x_values = x.values;
auto max_iters = opts_.max_its;
auto res_tol = opts_.residual_tol;
const int local_accessor_size = shared_size + 3 * sizeof(ValueType);
const int local_accessor_size = shared_size + 3;

(exec_->get_queue())->submit([&](sycl::handler& cgh) {
sycl::accessor<ValueType, 1, sycl::access_mode::read_write,
Expand All @@ -149,12 +149,14 @@ class KernelCaller {
gko::batch::batch_entry_ptr(x_values, 1, nrows,
batch_id);

ValueType* const slm_values_ptr =
slm_values.get_pointer();
real_type* const slm_reals_ptr =
slm_reals.get_pointer();
apply_kernel<StopType>(
sconf, max_iters, res_tol, logger, prec,
a_global_entry, b_global_entry, x_global_entry,
nrows, a.num_nnz,
static_cast<ValueType*>(slm_values.get_pointer()),
static_cast<real_type*>(slm_reals.get_pointer()),
nrows, a.num_nnz, slm_values_ptr, slm_reals_ptr,
item_ct1, workspace_data);
});
});
Expand Down

0 comments on commit a5474cc

Please sign in to comment.