diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index b8fb51304e4b0b..5a02d199ed6b03 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 && - num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads, + num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(), "BlockReduceSum requires all warps be active"); const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr(); dim3 grid = unique_indices.numel(); diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 3e67f5ad5bfbeb..72374095baac29 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -86,7 +86,7 @@ void renormRows(Tensor& t) { TORCH_CHECK(props != nullptr); int numSM = props->multiProcessorCount; const int64_t maxThreads = std::min( - props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads); + props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads()); int warp_size = at::cuda::warp_size(); dim3 grid(rows < numSM * 4 ? rows : numSM * 4); diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cu b/aten/src/ATen/native/cuda/TensorModeKernel.cu index b848ed5748e5c8..be158584cedb8d 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cu +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cu @@ -209,7 +209,7 @@ void handle_fused_mode( constexpr int num_threads = size / 2; int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 && - num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, ""); + num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), ""); const auto memsize = (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int)); compute_mode diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index df757a11761bba..fc44b0f4a9da0c 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -14,7 +14,17 @@ constexpr int kCUDABlockReduceNumThreads = 512; // of which reduces C10_WARP_SIZE elements. So, at most // C10_WARP_SIZE**2 elements can be reduced at a time. // NOTE: This is >= the max block size on current hardware anyway (1024). -constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; +// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions, +// and kCUDABlockReduceMaxThreads is a host-side variable. +#ifdef USE_ROCM +static int kCUDABlockReduceMaxThreads() { + return at::cuda::warp_size() * at::cuda::warp_size(); +} +#else +constexpr int kCUDABlockReduceMaxThreads() { + return C10_WARP_SIZE * C10_WARP_SIZE; +} +#endif // Sums `val` across all threads in a warp. //