Run-time checks for CUDA and cuBLAS versions #1938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Open · wants to merge 7 commits into main

@@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) {
}
return false;
#else
// Check run-time CUDA version
if (transformer_engine::cuda::cudart_version() < 12040) {
if (getenv("NVTE_UBDEBUG")) {
printf(
"TransformerEngine does not support multi-node NVLINK "
"since it is not being run with CUDA version >= 12.4.\n");
}
return false;
}

bool mnnvl_fabric_support = false;
CUdevice dev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
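Note: the check above distinguishes the CUDA toolkit the library was compiled against from the CUDA runtime it is actually loaded with, which may differ. A minimal standalone sketch of the two queries, using only the standard CUDART_VERSION macro and cudaRuntimeGetVersion() that the helpers in this PR wrap and cache:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  // Compile-time: the toolkit headers this binary was built with.
  std::printf("compile-time CUDA: %d\n", CUDART_VERSION);
  // Run-time: the libcudart actually loaded, possibly older or newer.
  int version = 0;
  if (cudaRuntimeGetVersion(&version) == cudaSuccess) {
    std::printf("run-time CUDA: %d\n", version);
  }
  return 0;
}
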
109 changes: 79 additions & 30 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -222,6 +222,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
return ret;
}

/* cuBLAS version number at run-time */
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}

} // namespace

namespace transformer_engine {
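
The function-local static in cublas_version() is the C++11 "initialize once" idiom: cublasLtGetVersion() is called on first use only (its result cannot change within a process), and the initialization is thread-safe. A generic sketch of the same memoization pattern, under the assumption that the queried value is immutable:

#include <cublasLt.h>

// Hypothetical generic helper: evaluate an expensive, immutable query once.
template <typename F>
auto memoized(F query) -> decltype(query()) {
  static const auto value = query();  // thread-safe one-time init (C++11)
  return value;
}
// Usage sketch: memoized([] { return cublasLtGetVersion(); })
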
@@ -342,10 +349,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&fastAccuMode, sizeof(fastAccuMode)));

// Scaling factors.
#if CUDA_VERSION >= 12080
#if CUBLAS_VERSION >= 120800
cublasLtMatmulMatrixScale_t scaling_mode_a;
cublasLtMatmulMatrixScale_t scaling_mode_b;
#endif
#endif // CUBLAS_VERSION >= 120800
if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) {
void *A_scale_inverse = param.A_scale_inv;
void *B_scale_inverse = param.B_scale_inv;
@@ -355,10 +362,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
#if CUBLAS_VERSION >= 120800
scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#endif // CUBLAS_VERSION >= 120800
} else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) {
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -371,17 +382,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublasLtGetVersion() <= 120803) {
if (cublas_version() <= 120803) {
const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
sizeof(dummy_a_vec_stride)));
}
#else
NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif // CUBLAS_VERSION >= 120800
} else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUDA_VERSION >= 12090
#if CUBLAS_VERSION >= 120900
NVTE_CHECK(cublas_version() >= 120900,
"FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
cublas_version());
float *A_scale_inverse = reinterpret_cast<float *>(param.A_scale_inv);
float *B_scale_inverse = reinterpret_cast<float *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
@@ -400,34 +418,40 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
? CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F
: CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
NVTE_ERROR("FP8 block scaling requires CUDA 12.9+");
#endif // CUDA_VERSION >= 12090
#endif // CUDA_VERSION >= 12080
NVTE_ERROR("FP8 block scaling requires cuBLAS 12.9+, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif // CUBLAS_VERSION >= 120900
} else {
NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
to_string(inputB->scaling_mode) + ".");
}

#if CUDA_VERSION >= 12080
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode_a, sizeof(scaling_mode_a)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
#endif
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
&scaling_mode_a, sizeof(scaling_mode_a)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
&scaling_mode_b, sizeof(scaling_mode_b)));
}
#endif // CUBLAS_VERSION >= 120800
if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output
C = nullptr;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUDA_VERSION >= 12080
// NOTE: In all current cases where FP8 output is supported, the input is
// scaled identically to the output.
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
&scaling_mode_a, sizeof(scaling_mode_a)));
#endif
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
// NOTE: In all current cases where FP8 output is supported, the input is
// scaled identically to the output.
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_D_SCALE_MODE,
&scaling_mode_a, sizeof(scaling_mode_a)));
}
#endif // CUBLAS_VERSION >= 120800
// For FP8 output, cuBLAS requires C_type to match bias_type and
// be FP16/BF16
const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
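
Each #if CUBLAS_VERSION block above now pairs a compile-time guard (so the code still builds against pre-12.8 headers, where the scale-mode enums do not exist) with a run-time branch (so a binary built with 12.8 headers behaves correctly when an older libcublasLt is loaded). The pattern in isolation, as a sketch with a hypothetical helper name:

#include <cublasLt.h>

// Hypothetical helper: can the new scale-mode attributes be used right now?
bool scale_mode_available() {
#if CUBLAS_VERSION >= 120800
  // Headers define the enum; verify the loaded library is also new enough.
  return cublasLtGetVersion() >= 120800;
#else
  // Built against pre-12.8 headers: the enum is not defined at all.
  return false;
#endif
}
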
@@ -495,9 +519,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));

if (counter != nullptr) {
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR(
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
CUBLAS_VERSION);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
if (counter != nullptr) {
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
Expand All @@ -515,8 +554,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER, &counter,
sizeof(counter)));
}
}
#endif
}

NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
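
Aside on the thresholds used throughout: CUDA and cuBLAS encode version numbers differently, which is why 12.2 appears as 12020 on the CUDA side while 12.2.5 appears as 120205 on the cuBLAS side. A small sketch of the assumed packings, consistent with every constant in this diff:

// Assumed encodings (matching the constants used above):
//   CUDA:   major * 1000  + minor * 10           -> 12.2   == 12020
//   cuBLAS: major * 10000 + minor * 100 + patch  -> 12.2.5 == 120205
constexpr int cuda_ver(int major, int minor) { return major * 1000 + minor * 10; }
constexpr int cublas_ver(int major, int minor, int patch = 0) {
  return major * 10000 + minor * 100 + patch;
}
static_assert(cuda_ver(12, 2) == 12020, "CUDA 12.2");
static_assert(cublas_ver(12, 2, 5) == 120205, "cuBLAS 12.2.5");
static_assert(cublas_ver(13, 0) == 130000, "cuBLAS 13.0.0");
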
@@ -600,15 +639,25 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_atomic_gemm);
using namespace transformer_engine;

int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
NVTE_CHECK(cudart_version >= 12020 && cudart_version < 13000,
"Cuda version >=12.2 and <13.0 is required for atomic gemm.");
NVTE_CHECK(cublasLtGetVersion() >= 120205 && cublasLtGetVersion() < 130000,
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm.");
// Check CUDA and cuBLAS versions
#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
CUBLAS_VERSION);
#endif
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
cublas_version());

using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
15 changes: 14 additions & 1 deletion transformer_engine/common/util/cuda_runtime.cpp
@@ -103,8 +103,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id)

bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
// NOTE: This needs to be guarded at compile time because the
// NOTE: This needs to be guarded at compile-time and run-time because the
// CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
if (cudart_version() < 12010) {
return false;
}
static std::vector<bool> cache(num_devices(), false);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
@@ -197,6 +200,16 @@ const std::string &include_directory(bool required) {
return path;
}

int cudart_version() {
auto get_version = []() -> int {
int version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&version));
return version;
};
static int version = get_version();
return version;
}

} // namespace cuda

} // namespace transformer_engine
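
A short usage sketch of the new cuda::cudart_version() helper; the include path below is illustrative, not the repository's actual layout:

#include <cstdio>
#include <cuda_runtime.h>  // CUDART_VERSION

#include "util/cuda_runtime.h"  // assumed include path for the header below

void warn_on_cuda_mismatch() {
  const int run = transformer_engine::cuda::cudart_version();
  if (run != CUDART_VERSION) {
    std::printf("CUDA mismatch: built against %d, running on %d\n",
                CUDART_VERSION, run);
  }
}
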
6 changes: 6 additions & 0 deletions transformer_engine/common/util/cuda_runtime.h
@@ -67,6 +67,12 @@ bool supports_multicast(int device_id = -1);
*/
const std::string &include_directory(bool required = false);

/* \brief CUDA Runtime version number at run-time
*
* Versions may differ between compile-time and run-time.
*/
int cudart_version();

} // namespace cuda

} // namespace transformer_engine
2 changes: 1 addition & 1 deletion transformer_engine/common/util/rtc.cpp
@@ -145,7 +145,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);
const bool compile_ptx = sm_arch_ != compile_sm_arch;

// Compilation flags
std::vector<std::string> opts = {
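
With the CUDA 11.0 special case removed, compile_ptx depends only on whether the device architecture exceeds what this NVRTC build can target; PTX is emitted in that case so the driver can JIT it forward. A simplified sketch of what the flag controls downstream, assuming NVRTC (which this file uses) and eliding error handling:

#include <nvrtc.h>
#include <string>

// Sketch: fetch PTX (driver-JIT-able) or a cubin (arch-exact) from a
// compiled NVRTC program, depending on the compile_ptx decision above.
std::string compiled_code(nvrtcProgram prog, bool compile_ptx) {
  size_t size = 0;
  std::string code;
  if (compile_ptx) {
    nvrtcGetPTXSize(prog, &size);
    code.resize(size);
    nvrtcGetPTX(prog, &code[0]);
  } else {
    nvrtcGetCUBINSize(prog, &size);
    code.resize(size);
    nvrtcGetCUBIN(prog, &code[0]);
  }
  return code;
}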