diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index e940207de5cfbb..7d1cea906f2f6a 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -182,8 +182,12 @@ __global__ void vectorized_templated_elementwise_kernel( loader_t loader, storer_t storer) { using traits = function_traits; + // Remainder chunk uses slower scheme than vectorized: + // assign it in the first threadblock to enable scheduler + // to start it as soon as possible, by inverting mapping + // of loop chunks to threadblocks. int remaining = - N - vectorized_templated_config::block_work_size() * blockIdx.x; + N - vectorized_templated_config::block_work_size() * (gridDim.x - blockIdx.x - 1); if (remaining < vectorized_templated_config::block_work_size()) { // if this block handles // the reminder, @@ -197,12 +201,10 @@ __global__ void vectorized_templated_elementwise_kernel( out_calc_t, loader_t, storer_t>(data, remaining, inp_calc, out_calc, loader, storer); - templated_elementwise_kernel_helper(f, policy); + templated_elementwise_kernel_helper(f, policy); } else { // if this block has a full `block_work_size` data to handle, use // vectorized memory access - templated_elementwise_kernel_helper< - vectorized_templated_config::thread_work_size()>( - f, + auto policy = memory::policies::vectorized_templated< vectorized_templated_config::thread_work_size(), vectorized_templated_config::num_threads(), @@ -210,7 +212,10 @@ __global__ void vectorized_templated_elementwise_kernel( vec_size, array_t, OutputType, - InputTypes...>(data)); + InputTypes...>(data); + templated_elementwise_kernel_helper( + f, policy); } } diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index 4bf9860239e0d1..440273dcfec0dd 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -68,6 +68,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t 
policy) { #ifdef USE_ROCM template < + bool reverted_idx=false, int thread_work_size = thread_work_size(), typename func_t, typename policy_t> @@ -79,6 +80,8 @@ __device__ inline void templated_elementwise_kernel_helper( using args_t = typename traits::ArgsTuple; int idx = blockIdx.x; + if constexpr (reverted_idx) + idx = gridDim.x - blockIdx.x - 1; return_t results[thread_work_size]; args_t args[thread_work_size];