From fd2a0432ae459fdabb6d3e5651ff4b918ab947fa Mon Sep 17 00:00:00 2001 From: Ethan Wee Date: Wed, 25 Jun 2025 15:10:51 -0700 Subject: [PATCH] [rocm7.0_internal_testing] Prevent static initialization of at::cuda::warp_size() (#2293) Fixes SWDEV-540240, SWDEV-540309, SWDEV-539989 ``` ... ``` https://github.com/ROCm/pytorch/commit/80cca7006d94df97ee932fd5903ed20c08c2eb34 created a static global variable that used `at::cuda::warp_size()` to initialize its value, which needs GPUs to be visible to query device properties. However, GPUs are not present on CPU-only build systems. Convert static variable into a static function, thus preventing static initialization. http://rocm-ci.amd.com/job/pyt_whl_docker_mainline/1461/artifact/build_artifacts.txt/*view*/ Ran microbenchmark to confirm basic functionality: ``` root@ubb4-rack-22:/var/lib/jenkins/pytorch-micro-benchmarking# python3 micro_benchmarking_pytorch.py --network resnet50 INFO: running forward and backward for warmup. INFO: running the benchmark.. OK: finished running benchmark.. --------------------SUMMARY-------------------------- Microbenchmark for network : resnet50 Num devices: 1 Dtype: FP32 Mini batch size [img] : 64 Time per mini-batch : 0.10158218145370483 Throughput [img/sec] : 630.0317544289736= ``` --- aten/src/ATen/native/cuda/Embedding.cu | 2 +- aten/src/ATen/native/cuda/MultinomialKernel.cu | 2 +- aten/src/ATen/native/cuda/TensorModeKernel.cu | 2 +- aten/src/ATen/native/cuda/block_reduce.cuh | 12 +++++++++++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index b8fb51304e4b0b..5a02d199ed6b03 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 && - num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads, + num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(), "BlockReduceSum requires all warps be active"); const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr(); dim3 grid = unique_indices.numel(); diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 3e67f5ad5bfbeb..72374095baac29 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -86,7 +86,7 @@ void renormRows(Tensor& t) { TORCH_CHECK(props != nullptr); int numSM = props->multiProcessorCount; const int64_t maxThreads = std::min( - props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads); + props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads()); int warp_size = at::cuda::warp_size(); dim3 grid(rows < numSM * 4 ? rows : numSM * 4); diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cu b/aten/src/ATen/native/cuda/TensorModeKernel.cu index b848ed5748e5c8..be158584cedb8d 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cu +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cu @@ -209,7 +209,7 @@ void handle_fused_mode( constexpr int num_threads = size / 2; int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 && - num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, ""); + num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), ""); const auto memsize = (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int)); compute_mode diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh index df757a11761bba..fc44b0f4a9da0c 100644 --- a/aten/src/ATen/native/cuda/block_reduce.cuh +++ b/aten/src/ATen/native/cuda/block_reduce.cuh @@ -14,7 +14,17 @@ constexpr int kCUDABlockReduceNumThreads = 512; // of which reduces C10_WARP_SIZE elements. So, at most // C10_WARP_SIZE**2 elements can be reduced at a time. // NOTE: This is >= the max block size on current hardware anyway (1024). -constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; +// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions, +// and kCUDABlockReduceMaxThreads is a host-side variable. +#ifdef USE_ROCM +static int kCUDABlockReduceMaxThreads() { + return at::cuda::warp_size() * at::cuda::warp_size(); +} +#else +constexpr int kCUDABlockReduceMaxThreads() { + return C10_WARP_SIZE * C10_WARP_SIZE; +} +#endif // Sums `val` across all threads in a warp. //