From e2de359256de6c6bd4c977fc96bad6893f40af75 Mon Sep 17 00:00:00 2001
From: Carlo Bertolli
Date: Wed, 19 Mar 2025 11:06:57 -0500
Subject: [PATCH] Schedule remainder loop chunk in threadblock 0.

The remainder loop chunk is not executed using the input vectorized
elementwise kernel, but uses standard unrolling, so on average it is
expected to run slower than the other chunks. To prevent a long tail,
schedule the remainder loop chunk as threadblock 0 so that it is
scheduled first and hopefully completes execution before the other,
faster chunks.
---
 aten/src/ATen/native/cuda/CUDALoops.cuh | 17 +++++++++++------
 aten/src/ATen/native/cuda/Loops.cuh     |  3 +++
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index e940207de5cfb..7d1cea906f2f6 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -182,8 +182,12 @@ __global__ void vectorized_templated_elementwise_kernel(
     loader_t loader,
     storer_t storer) {
   using traits = function_traits<func_t>;
+  // The remainder chunk uses a slower scheme than the vectorized one:
+  // assign it to the first threadblock to enable the scheduler
+  // to start it as soon as possible, by inverting the mapping
+  // of loop chunks to threadblocks.
   int remaining =
-      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+      N - vectorized_templated_config::block_work_size() * (gridDim.x - blockIdx.x - 1);
   if (remaining < vectorized_templated_config::block_work_size()) { // if this block handles
     // the reminder,
@@ -197,12 +201,10 @@
         out_calc_t,
         loader_t,
         storer_t>(data, remaining, inp_calc, out_calc, loader, storer);
-    templated_elementwise_kernel_helper(f, policy);
+    templated_elementwise_kernel_helper<true>(f, policy);
   } else {
     // if this block has a full `block_work_size` data to handle, use
     // vectorized memory access
-    templated_elementwise_kernel_helper<
-        vectorized_templated_config::thread_work_size()>(
-        f,
+    auto policy =
         memory::policies::vectorized_templated<
             vectorized_templated_config::thread_work_size(),
             vectorized_templated_config::num_threads(),
@@ -210,7 +212,10 @@
             vec_size,
             array_t,
             OutputType,
-            InputTypes...>(data));
+            InputTypes...>(data);
+    templated_elementwise_kernel_helper<true,
+        vectorized_templated_config::thread_work_size()>(
+        f, policy);
   }
 }
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index 4bf9860239e0d..440273dcfec0d 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -68,6 +68,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
 
 #ifdef USE_ROCM
 template <
+    bool reverted_idx=false,
     int thread_work_size = thread_work_size(),
     typename func_t,
     typename policy_t>
@@ -79,6 +80,8 @@ __device__ inline void templated_elementwise_kernel_helper(
   using args_t = typename traits::ArgsTuple;
 
   int idx = blockIdx.x;
+  if constexpr (reverted_idx)
+    idx = gridDim.x - blockIdx.x - 1;
 
   return_t results[thread_work_size];
   args_t args[thread_work_size];
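
Note (illustration only, not part of the patch): the scheduling trick above is just an
inversion of the chunk-to-threadblock mapping, so that the slow, bounds-checked
remainder chunk is dispatched first instead of last. Below is a minimal,
self-contained CUDA sketch of the same idea. BLOCK_WORK_SIZE, NUM_THREADS, and
inverted_chunk_kernel are hypothetical stand-ins invented for this sketch; the
real kernel gets its sizes from vectorized_templated_config and uses vectorized
loads/stores in the fast path.

#include <cuda_runtime.h>

// Hypothetical stand-ins; the real kernel reads these from
// vectorized_templated_config.
constexpr int BLOCK_WORK_SIZE = 256;
constexpr int NUM_THREADS = 64;

__global__ void inverted_chunk_kernel(const float* in, float* out, int N) {
  // Invert the mapping of loop chunks to threadblocks: threadblock 0
  // takes the LAST chunk, which is the (slower, bounds-checked)
  // remainder whenever N is not a multiple of BLOCK_WORK_SIZE.
  int chunk = gridDim.x - blockIdx.x - 1;
  int base = chunk * BLOCK_WORK_SIZE;
  int remaining = N - base;
  if (remaining < BLOCK_WORK_SIZE) {
    // Remainder chunk: bounds-checked loop (stand-in for the
    // unrolled, non-vectorized path).
    for (int i = threadIdx.x; i < remaining; i += blockDim.x)
      out[base + i] = in[base + i] * 2.0f;
  } else {
    // Full chunk: fixed trip count, no bounds checks (stand-in for
    // the vectorized path).
    for (int i = threadIdx.x; i < BLOCK_WORK_SIZE; i += blockDim.x)
      out[base + i] = in[base + i] * 2.0f;
  }
}

// Launch with one threadblock per chunk:
//   int blocks = (N + BLOCK_WORK_SIZE - 1) / BLOCK_WORK_SIZE;
//   inverted_chunk_kernel<<<blocks, NUM_THREADS>>>(d_in, d_out, N);

Because threadblocks are typically dispatched in increasing blockIdx.x order,
the remainder chunk is issued first and its longer latency overlaps with the
faster full chunks instead of trailing them.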