Skip to content

Schedule remainder loop chunk in threadblock 0. #1985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: release/2.5
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,12 @@ __global__ void vectorized_templated_elementwise_kernel(
loader_t loader,
storer_t storer) {
using traits = function_traits<func_t>;
// Remainder chunk uses slower scheme than vectorized:
// assign it in the first threadblock to enable scheduler
// to start is as soon as possible, by inverting mapping
// of loop chunks to threadblocks.
int remaining =
N - vectorized_templated_config::block_work_size() * blockIdx.x;
N - vectorized_templated_config::block_work_size() * (gridDim.x - blockIdx.x - 1);
if (remaining <
vectorized_templated_config::block_work_size()) { // if this block handles
// the reminder,
Expand All @@ -197,20 +201,21 @@ __global__ void vectorized_templated_elementwise_kernel(
out_calc_t,
loader_t,
storer_t>(data, remaining, inp_calc, out_calc, loader, storer);
templated_elementwise_kernel_helper(f, policy);
templated_elementwise_kernel_helper</*reverted_idx=*/true>(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
templated_elementwise_kernel_helper<
vectorized_templated_config::thread_work_size()>(
f,
auto policy =
memory::policies::vectorized_templated<
vectorized_templated_config::thread_work_size(),
vectorized_templated_config::num_threads(),
vectorized_templated_config::block_work_size(),
vec_size,
array_t,
OutputType,
InputTypes...>(data));
InputTypes...>(data);
templated_elementwise_kernel_helper</*reverted_idx=*/true,
vectorized_templated_config::thread_work_size()>(
f, policy);
}
}

Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/native/cuda/Loops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {

#ifdef USE_ROCM
template <
bool reverted_idx=false,
int thread_work_size = thread_work_size(),
typename func_t,
typename policy_t>
Expand All @@ -79,6 +80,8 @@ __device__ inline void templated_elementwise_kernel_helper(
using args_t = typename traits::ArgsTuple;

int idx = blockIdx.x;
if constexpr (reverted_idx)
idx = gridDim.x - blockIdx.x - 1;

return_t results[thread_work_size];
args_t args[thread_work_size];
Expand Down