diff --git a/build_tools/utils.py b/build_tools/utils.py index 0dc5e36898..5004f824f3 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -237,9 +237,9 @@ def cuda_archs() -> str: version = cuda_version() if os.getenv("NVTE_CUDA_ARCHS") is None: if version >= (13, 0): - os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120" + os.environ["NVTE_CUDA_ARCHS"] = "100a" elif version >= (12, 8): - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120" + os.environ["NVTE_CUDA_ARCHS"] = "100a" else: os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90" return os.getenv("NVTE_CUDA_ARCHS") diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index afc80cba43..ae4d7e42c5 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -6,7 +6,7 @@ cmake_minimum_required(VERSION 3.18) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + set(CMAKE_CUDA_ARCHITECTURES 100) else () set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) endif() @@ -26,7 +26,7 @@ enable_testing() include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) if(NOT DEFINED TE_LIB_PATH) - execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" + execute_process(COMMAND bash -c "pip3 show transformer-engine | grep 'Editable project location' | cut -d ' ' -f 4 | tr -d '\n'" OUTPUT_VARIABLE TE_LIB_PATH) endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0b0e615495..a7a31a19f5 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,29 +3,30 @@ # See LICENSE for license information. add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_causal_softmax.cu - test_swizzle.cu + # test_cast.cu + # test_cast_current_scaling.cu + # test_cast_dbias.cu + # test_cast_dbias_dgelu.cu + # test_cast_gated_swiglu.cu + # test_cast_mxfp8_gated_swiglu.cu + # test_qdq.cu + # test_cast_mxfp8.cu + test_cast_nvfp4.cu + # test_cast_float8blockwise.cu + # test_dequantize_mxfp8.cu + # test_transpose.cu + # test_cast_transpose.cu + # test_cast_transpose_current_scaling.cu + # test_cast_transpose_dbias.cu + # test_cast_transpose_dbias_dgelu.cu + # test_cast_transpose_dgeglu.cu + # test_act.cu + # test_normalization.cu + # test_normalization_mxfp8.cu + # test_multi_cast_transpose.cu + # test_multi_padding.cu + # test_causal_softmax.cu + # test_swizzle.cu ../test_common.cu) find_package(OpenMP REQUIRED) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index bea9887369..cf0b044f7e 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -36,95 +36,34 @@ enum ActivationType { SReLU }; -template -void scale_block(const ProcessingMethod processing_method, +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, const InputType* 
input, const InputType* grad, - OutputType* output_c, - float* dbias, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { - float amax = 0.0f; - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - dbias[j] += elt; - if (isinf(elt) || isnan(elt)) { - continue; - } - amax = std::max(amax, std::abs(elt)); - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal()); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - - // Quantize elements in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - output_c[idx] = static_cast(elt * scale_reciprocal); - } - } -} - -template -void compute_ref_x1(const ProcessingMethod processing_method, - const InputType* input, - const InputType* grad, - OutputType* output_c, - fp8e8m0* output_scales, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { - const size_t tile_size_Y = std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; std::vector output_dbias_fp32(cols, 0); #pragma omp parallel proc_bind(spread) { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + std::vector thread_dbias(cols, 0); #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { @@ -133,24 +72,83 @@ void compute_ref_x1(const ProcessingMethod processing_method, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t 
block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X; - scale_block( - processing_method, input, grad, output_c, thread_dbias.data(), - output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + + cache_buffer[cache_idx] = elt; + thread_dbias[j] += elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const int scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const int scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } } } } @@ -166,29 +164,6 @@ void compute_ref_x1(const ProcessingMethod processing_method, } } -template -void compute_ref_x2(const ProcessingMethod processing_method, - const 
InputType* input, - const InputType* grad, - OutputType* output_rowwise, - OutputType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, - rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - processing_method, input, grad, output_colwise, scales_colwise, output_dbias, - rows, cols, block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -197,8 +172,9 @@ void compute_ref_x2(const ProcessingMethod processing_method, * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const bool rowwise, const bool colwise, @@ -261,28 +237,46 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output_c.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output_c.data(), 0); + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output_c.data(), 0); break; } } @@ -291,29 +285,45 @@ void performTest_x1(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - compute_ref_x1(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c.get(), - ref_output_scales.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, 
ref_output_c.get(), rowwise, atol, rtol); + compute_ref(processing_method, + OP, + rowwise, + colwise, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c.get(), + ref_output_c.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride, + scales_stride); const uint8_t * const gpu_scales_ptr = rowwise ? output_c.rowwise_cpu_scale_inv_ptr() : output_c.columnwise_cpu_scale_inv_ptr(); - compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + size_t mismatches_scales = 0; + compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts); - if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -332,8 +342,9 @@ void performTest_x1(const ProcessingMethod processing_method, * AND * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x2(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const size_t block_size_rows, const size_t block_size_cols, @@ -401,28 +412,46 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output.data(), 0); + auto nvte_act = &nvte_gelu; 
+ if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output.data(), 0); break; } } @@ -431,32 +460,54 @@ void performTest_x2(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - compute_ref_x2(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c_rowwise.get(), - ref_output_c_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); + compute_ref(processing_method, + OP, + true, + true, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise); + + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + size_t mismatches_scales_rowwise = 0; + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + size_t mismatches_scales_colwise = 0; + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; auto [atol, rtol] = getTolerances(otype); - compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise); - - if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_scales_rowwise); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_scales_colwise); + + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -528,26 +579,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam transformer_engine::DType, InputsFillCase>> {}; -#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) 
\ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \ -} - -#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \ -} - TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { // Skip tests for pre-Blackwell architectures if (getDeviceComputeCapability() < blackwellComputeCapability) { @@ -581,35 +612,48 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { const bool colwise = block_size.first != 1; if (processing_method == ProcessingMethod::CAST_ACT) { // Forward activations - ACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } else { - DACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &dgelu; break; + case ActivationType::SiLU: OP = &dsilu; break; + case ActivationType::ReLU: OP = &drelu; break; + case ActivationType::QGeLU: OP = &dqgelu; break; + case ActivationType::SReLU: OP = &dsrelu; break; + } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if 
(block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 2b22942f84..553a3c44f8 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -18,107 +18,32 @@ using namespace test; namespace { -template -void scale_block(const IType* grad, +template +void compute_ref(const IType* grad, const IType* input, - OType* output, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t scale_idx_gate, - float& thread_amax, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { - - float block_amax = 0.0f; - float block_amax_gate = 0.0f; - const size_t stride = cols * 2; - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - float gated_amax_act = 0; - float gated_amax_gate = 0; - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - gated_amax_act = abs(after_dsilu); - gated_amax_gate = abs(after_dgate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - gated_amax_act = abs(after_silu); - } - - if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } - if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * - Quantized_Limits::max_reciprocal()); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - float scale_reciprocal_gate = 1; - if constexpr (IS_DGATED) { - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate * - Quantized_Limits::max_reciprocal()); - scale_reciprocal_gate = exp2f_rcp(biased_exponent); - output_scales[scale_idx_gate] = biased_exponent; - } - - - // Quantize elements in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal); - output[i * stride + cols + j] = static_cast(after_dgate * - scale_reciprocal_gate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - output[i * cols + j] = static_cast(after_silu * scale_reciprocal); - } - - } - } - thread_amax = std::max(thread_amax, block_amax); - thread_amax = std::max(thread_amax, block_amax_gate); -} - -template -void compute_ref_x1(const IType* grad, - const IType* input, - OType* output, - fp8e8m0* output_scales, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) { - const size_t tile_size_Y = 
std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + OType* output_rowwise, + OType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + float& ref_amax, + const bool IS_DGATED, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise, + const bool is_rowwise, + const bool is_colwise) { + constexpr size_t tile_size_Y = 32; + constexpr size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; - float amax = 0; #pragma omp parallel reduction(max: amax) proc_bind(spread) { - float thread_amax = 0; + // Buffers to cache intermediate computations + std::vector cache_buffer_act(tile_size_Y * tile_size_X); + std::vector cache_buffer_gate(tile_size_Y * tile_size_X); + float thread_amax = 0.0f; #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { const size_t tile_Y = t / tiles_num_X; @@ -126,26 +51,124 @@ void compute_ref_x1(const IType* grad, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X; - const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X + - cols / block_size_X; - scale_block( - grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate, - thread_amax, i_min, i_max, j_min, j_max, cols); + const size_t stride = cols * 2; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(rows, tile_offset_Y + tile_size_Y); + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(cols, tile_offset_X + tile_size_X); + + // Compute and cache activations for the entire tile + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + + if (IS_DGATED) { + const float x = silu_elt; + const float s = sigmoid(x); + const float act_x = x * s; + const float dact_x = x * s * (1 - s) + s; + + const float grad_elt = static_cast(grad[i * cols + j]); + float after_dsilu = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_dsilu = static_cast(static_cast(after_dsilu)); + after_dgate = static_cast(static_cast(after_dgate)); + + cache_buffer_act[cached_idx] = after_dsilu; + cache_buffer_gate[cached_idx] = after_dgate; + thread_amax = std::max(thread_amax, std::abs(after_dsilu)); + thread_amax = 
std::max(thread_amax, std::abs(after_dgate)); + } else { + float after_silu = silu(silu_elt) * gate_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_silu = static_cast(static_cast(after_silu)); + + cache_buffer_act[cached_idx] = after_silu; + thread_amax = std::max(thread_amax, std::abs(after_silu)); + } + } + } + + if (is_rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const int scale_idx_act = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; + output_scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t j = j_min; j < j_max; ++j) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_rowwise[i * stride + j] = static_cast(after_act); + output_rowwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_rowwise[i * cols + j] = static_cast(after_act); + } + } + } + } + + if (is_colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const int scale_idx_act = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + const int scale_idx_gate = scale_idx_act + cols; + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + output_scales_colwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t i = i_min; i < i_max; ++i) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_colwise[i * stride + j] = static_cast(after_act); + output_colwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_colwise[i * cols + j] = static_cast(after_act); + } + } } } } @@ -156,26 +179,6 
@@ void compute_ref_x1(const IType* grad, ref_amax = amax; } -template -void compute_ref_x2(const IType* grad, - const IType* input, - OType* output_rowwise, - OType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -183,12 +186,13 @@ void compute_ref_x2(const IType* grad, * OR * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -198,12 +202,6 @@ void performTest_x1(const size_t rows, const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); NVTE_CHECK(rowwise || colwise); - // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl; - // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl; - // std::cout << "blocks_Y: " << blocks_Y << std::endl; - // std::cout << "blocks_X: " << blocks_X << std::endl; - // std::cout << "scales_stride: " << scales_stride << std::endl; - Tensor grad("grad", std::vector{ rows, cols }, itype); Tensor input("input", std::vector{ rows, cols * 2 }, itype); @@ -229,12 +227,12 @@ void performTest_x1(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -245,30 +243,48 @@ void performTest_x1(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x1(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_scales.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); - - auto [atol, rtol] = getTolerances(otype); - compareResults("output", output, ref_output.get(), rowwise, atol, rtol); + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride, + scales_stride, + rowwise, + colwise); + + size_t mismatches_scales = 0; + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; const uint8_t * const gpu_scales_ptr = rowwise ? 
output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { - compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { - compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } + + const size_t mismatches_elts = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts); } /** @@ -278,12 +294,13 @@ void performTest_x1(const size_t rows, * AND * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x2(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -325,12 +342,12 @@ void performTest_x2(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -341,30 +358,49 @@ void performTest_x2(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x2(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output_rowwise.get(), - ref_output_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise, + true, + true); + + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; + + size_t mismatches_scales_rowwise = 0; + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + size_t mismatches_scales_colwise = 0; + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + 
const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; auto [atol, rtol] = getTolerances(otype); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise); + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise); } std::vector> matrix_sizes = { @@ -393,9 +429,9 @@ std::vector input_scenarios = { // InputsFillCase::maxNorm_to_inf }; -std::vector is_dgated_op = { - true, - false +std::vector is_bwd_op = { + false, + true }; } // namespace @@ -427,21 +463,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, if (block_size.first == 1 || block_size.second == 1) { - if (IS_DGATED) { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } else { - if (IS_DGATED) { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } ); ); @@ -456,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), - ::testing::ValuesIn(is_dgated_op)), + ::testing::ValuesIn(is_bwd_op)), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + @@ -465,6 +491,6 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<2>(info.param)) + "X" + test::typeName(std::get<3>(info.param)) + "X" + test::caseName(std::get<4>(info.param)) + "X" + - (std::get<5>(info.param) ? "DGATED" : "GATED"); + (std::get<5>(info.param) ? "BWD" : "FWD"); return name; }); diff --git a/tests/cpp/operator/test_cast_nvfp4.cu b/tests/cpp/operator/test_cast_nvfp4.cu new file mode 100644 index 0000000000..60764fa873 --- /dev/null +++ b/tests/cpp/operator/test_cast_nvfp4.cu @@ -0,0 +1,468 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ScalingType { + ROWWISE = 0, + BIDIMENSIONAL = 1 +}; + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +float2 cvt_fp4x2_to_float2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); + const float truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); + const float truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); + return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +} + +template +void compute_ref(const bool rowwise, + const bool colwise, + float (*OP)(const float), + const InputType* input, + fp4e2m1x2* output_rowwise_nvfp4, + OutputType* output_colwise_mxfp8, + fp8e4m3* scales_rowwise_nvfp4, + fp8e8m0* scales_colwise_mxfp8, + const float nvfp4_second_stage_scale, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ + const size_t tile_size_Y = 32; + const size_t tile_size_X = 16; + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = 1.0f / nvfp4_second_stage_scale; + + #pragma omp parallel proc_bind(spread) + { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + const float elt = static_cast(static_cast(act_elt)); + + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + // 2. 
Compute E4M3 scaling factor + // Compute per-block encoding/decoding scaling factor + const float S_dec_b = block_amax / 6.0f; + + // Scale & Store per-block decoding scaling factor + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + + // Compute "correct" per-block encoding scaling factor + const float S_enc_b_fp8 = S_enc / static_cast(S_dec_b_fp8); + + const int scale_idx = i * scales_stride_rowwise + tile_X; + scales_rowwise_nvfp4[scale_idx] = S_dec_b_fp8; + const float scale_reciprocal = S_enc_b_fp8; + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = (i - i_min) * tile_size_X + (j - j_min); + const int cache_idx_y = (i - i_min) * tile_size_X + (j + 1 - j_min); + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output_rowwise_nvfp4[idx_pair] = casted_to_e2m1_pair; + + const float2 truncated_pair = cvt_fp4x2_to_float2(casted_to_e2m1_pair); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const int scale_idx = tile_Y * scales_stride_colwise + j; + scales_colwise_mxfp8[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise_mxfp8[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + } + } +} + + +void compareResults_nvfp4(const std::string &name, const Tensor &test, + const void *ref, const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true) { + const std::string direction = "rowwise"; + + if (if_on_gpus) test.to_cpu(); + + const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); + const fp4e2m1 *ref_data = reinterpret_cast(ref); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + const int idx = i * cols + j; + + const __nv_fp4x2_storage_t* test_raw_storage = reinterpret_cast(&test_data[idx/2]); + const __nv_fp4x2_storage_t* ref_raw_storage = reinterpret_cast(&ref_data[idx/2]); + + const __half2_raw test_data_pair_raw = __nv_cvt_fp4x2_to_halfraw2(*test_raw_storage, __NV_E2M1); + const __half2_raw ref_data_pair_raw = __nv_cvt_fp4x2_to_halfraw2(*ref_raw_storage, __NV_E2M1); + + const __half2 test_data_pair(test_data_pair_raw); + const __half2 ref_data_pair(ref_data_pair_raw); + + for (int k = 0; k < 2; ++k) { + const double t = static_cast(k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = static_cast(k == 0 ? ref_data_pair.x : ref_data_pair.y); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? 
mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "Mismatch at place " + << " (" << std::to_string(idx + k) << "): " + << t << " vs " << r; + } + } + } + } +} + +/** + * Scaling along selected dimensions + * Produces sets of output data: + * 1) NVFP4 Scaled rows + E4M3 row-wise scaling factors + * AND + * 2) MXFP8 Scaled columns + E8M0 column-wise scaling factors + */ + +template +void performTest(float (*OP)(const float), + const std::vector& shape, + const bool colwise, + InputsFillCase fill_case) { + using namespace test; + + constexpr bool rowwise_true = true; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1); + + const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor input("input", shape, itype); + Tensor output("output", shape, otype, rowwise_true, colwise, NVTE_FWD_NVFP4_BWD_MXFP8_SCALING); + + std::unique_ptr ref_output_nvfp4; + std::unique_ptr ref_output_mxfp8; + std::unique_ptr ref_scales_nvfp4; + std::unique_ptr ref_scales_mxfp8; + + ref_output_nvfp4 = std::make_unique(rows * cols / 2); + ref_scales_nvfp4 = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + + if (colwise) { + ref_output_mxfp8 = std::make_unique(rows * cols); + ref_scales_mxfp8 = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + } + + fillCase(&input, fill_case); + setRandomScale(&output); + + auto nvte_quantize_operation = &nvte_quantize; + if (OP == &gelu) { nvte_quantize_operation = &nvte_gelu; } + else if (OP == &silu) { nvte_quantize_operation = &nvte_silu; } + else if (OP == &relu) { nvte_quantize_operation = &nvte_relu; } + else if (OP == &qgelu) { nvte_quantize_operation = &nvte_qgelu; } + else if (OP == &srelu) { nvte_quantize_operation = &nvte_srelu; } + + nvte_quantize_operation(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref(rowwise_true, + colwise, + OP, + input.rowwise_cpu_dptr(), + ref_output_nvfp4.get(), + ref_output_mxfp8.get(), + ref_scales_nvfp4.get(), + ref_scales_mxfp8.get(), + output.scale(), + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise); + + const double atol = 0.05; + const double rtol = 0.1; + compareResults_nvfp4("rowwise_nvfp4", output, ref_output_nvfp4.get(), rows, cols, atol, rtol); + + size_t scale_mismatches_num 
= 0; + compare_scaling_factors("rowwise_scales_E4M3", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_nvfp4.get(), + unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, + scale_mismatches_num); + + if (colwise) { + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + size_t mismatches_scales = 0; + compare_scaling_factors("colwise_scales_E8M0", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_mxfp8.get(), + unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts_limit = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + + compareResults("colwise_mxfp8", output, ref_output_mxfp8.get(), false, atol, rtol, true, mismatches_elts_limit); + } +} + +std::vector> matrix_sizes = { + {1, 32}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {511, 6144}, + {8192, 128}, + {2048, 160}, + {577, 1632}, + {1024}, + {8, 32, 1024}, + {16, 8, 4, 512}, +}; + +std::vector scaling_case = { + ScalingType::ROWWISE, // Row-wise NVFP4 {1, 16} + ScalingType::BIDIMENSIONAL // {32, 16} Row-wise NVFP4 AND Column-wise MXFP8 +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, + // InputsFillCase::zeros, + // InputsFillCase::zero_to_minNorm, + // InputsFillCase::minNorm_to_maxNorm, + // InputsFillCase::maxNorm_to_inf +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + // ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + +} // namespace + +class FusedCastNVFP4TestSuite : public ::testing::TestWithParam + , + ScalingType, + transformer_engine::DType, + transformer_engine::DType, + InputsFillCase>> {}; + +TEST_P(FusedCastNVFP4TestSuite, TestFusedCastNVFP4) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ActivationType Act_type = std::get<0>(GetParam()); + const auto tensor_dims = std::get<1>(GetParam()); + const ScalingType scaling_case = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + + const bool colwise = (scaling_case == ScalingType::BIDIMENSIONAL); + + // Skip tests with colwise scaling, if the input tensor is 1D + if (tensor_dims.size() < 2 && colwise) { + GTEST_SKIP(); + } + + // Forward activations + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(OP, tensor_dims, colwise, fill_case); + ); + ); +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "CAST_ONLY"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: 
return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +std::string to_string(const ScalingType scaling_type) { + switch (scaling_type) { + case ScalingType::ROWWISE: return "ROWWISE_NVFP4_1x16"; + case ScalingType::BIDIMENSIONAL: return "BIDIMENSIONAL_32x16"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(matrix_sizes), + ::testing::ValuesIn(scaling_case), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + to_string(std::get<2>(info.param)) + + "X" + test::typeName(std::get<3>(info.param)) + + "X" + test::typeName(std::get<4>(info.param)) + + "X" + test::caseName(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 0f64d7c01b..74ae64275f 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -133,6 +133,40 @@ std::pair get_scales(const NVTEShape& shape, ret.type_size_bits = typeToNumBits(DType::kFloat32); return {ret, ret}; } + if (scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + NVTE_CHECK(last_dim % 2 == 0); + + scale_inv_meta ret_rowwise, ret_colwise; + + auto block_alignment = std::vector{128ul, 4ul}; + { + auto alignment = block_alignment[1]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; + alignment = block_alignment[0]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(16)), alignment) * alignment; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto alignment = block_alignment[0]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; + alignment = block_alignment[1]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat8E4M3; + ret_colwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + + return {ret_rowwise, ret_colwise}; + } if (scaling_mode == NVTE_MXFP8_1D_SCALING) { std::vector shape_vec; for (size_t i = 0; i < shape.ndim; ++i) { @@ -257,7 +291,7 @@ Tensor::Tensor(const std::string& name, columnwise_shape_vec.emplace_back(shape.data[i]); } } else { - // Same shape for MX + // Same shape for MX and NVFP4 for (size_t i = 0; i < shape.ndim; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } @@ -283,54 +317,65 @@ Tensor::Tensor(const std::string& name, std::fill_n(cpu_data_columnwise_.get(), total_size, 0); } } - tensor_.set_rowwise_data(dptr_rowwise, type, shape); - tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); - - if (isFp8Type(type)) { - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { - cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) - cudaMemset(amax, 0, 
sizeof(float)); - cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) - cudaMemset(scale, 0, sizeof(float)); - amax_cpu_data_ = std::make_shared(0); - scale_cpu_data_ = std::make_shared(0); - tensor_.set_amax(amax, DType::kFloat32, std::vector{1}); - tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); - cudaMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*) - if (rowwise) { - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); - rowwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); - std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0); - } - if (columnwise) { - tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); - columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); - std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); - } - } else { - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(normalized_shape, tensor_.scaling_mode()); - auto rowwise_scale_size = rowwise_scale_meta.bytes(); - auto columnwise_scale_size = colwise_scale_meta.bytes(); - auto scale_shape = rowwise_scale_meta.shape; - auto columnwise_scale_shape = colwise_scale_meta.shape; - if (rowwise) { - cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) - cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); - rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); - std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); - auto scale_dtype = rowwise_scale_meta.type; - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); - } - if (columnwise) { - cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) - cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); - columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); - std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); - auto scale_dtype = colwise_scale_meta.type; - tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); + + if (scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + NVTE_CHECK(isFp8Type(type) && "Invalid data type!"); + tensor_.set_rowwise_data(dptr_rowwise, DType::kFloat4E2M1, shape); + tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); + + // Used for NVFP4 second stage scaling + cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + cudaMemset(scale, 0, sizeof(float)); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); + auto rowwise_scale_size = rowwise_scale_meta.bytes(); + auto columnwise_scale_size = colwise_scale_meta.bytes(); + auto scale_shape = rowwise_scale_meta.shape; + auto columnwise_scale_shape = colwise_scale_meta.shape; + if (rowwise) { + cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); + rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); + auto scale_dtype = rowwise_scale_meta.type; + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); + } + if (columnwise) { + cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) + cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); + columnwise_scale_inv_cpu_data_ = 
std::make_unique(columnwise_scale_size); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); + auto scale_dtype = colwise_scale_meta.type; + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); + } + } else { + tensor_.set_rowwise_data(dptr_rowwise, type, shape); + tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); + if (isFp8Type(type) || isFp4Type(type)) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) + cudaMemset(amax, 0, sizeof(float)); + cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + cudaMemset(scale, 0, sizeof(float)); + amax_cpu_data_ = std::make_shared(0); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_amax(amax, DType::kFloat32, std::vector{1}); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + cudaMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*) + if (rowwise) { + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + rowwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + if (columnwise) { + tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } } } } @@ -346,13 +391,19 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (columnwise_) { + const DType colwise_type = (tensor_.scaling_mode() == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) + ? DType::kFloat8E4M3 + : tensor_.dtype(); + + const size_t colwise_size = bytes(s, colwise_type); cudaMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - size, - cudaMemcpyDeviceToHost); + tensor_.get_columnwise_data().data_ptr, + colwise_size, + cudaMemcpyDeviceToHost); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING)) { if (tensor_.amax() != nullptr){ cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), @@ -364,8 +415,7 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -394,15 +444,15 @@ void Tensor::from_cpu() const { cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING)) { if (tensor_.amax() != nullptr){ cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = 
rowwise_scale_meta.bytes(); cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -419,9 +469,10 @@ void Tensor::from_cpu() const { } void Tensor::set_scale(float scale) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING)) { *scale_cpu_data_ = scale; from_cpu(); } @@ -429,7 +480,7 @@ void Tensor::set_scale(float scale) { } void Tensor::set_scale_inv(float scale_inv) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { NVTE_CHECK(rowwise_scale_inv_cpu_data_); } @@ -437,8 +488,7 @@ void Tensor::set_scale_inv(float scale_inv) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(tensor_.shape(), tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); if (num_scales == 1) { @@ -468,7 +518,8 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) + || isFp4Type(dtype()) && isFp4Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), @@ -523,10 +574,13 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { void compareResults_sequential(const std::string &name, const Tensor &test, const void *ref, const bool rowwise, - double atol, double rtol, bool if_on_gpus) { + double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches_num = 0; + int first_mismatch_idx = -1; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); @@ -547,80 +601,102 @@ void compareResults_sequential(const std::string &name, const Tensor &test, assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); } std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl + << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape)) + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } } ); } template static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, - const size_t N, const double atol, const double rtol) { + const size_t N, const double atol, const double rtol, + size_t& mismatches) { int first_mismatch_idx = N; - bool is_mismatch_found = false; - #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ - reduction(min: first_mismatch_idx) proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - if (is_mismatch_found) { // early escape of the omp thread - continue; - } - - double t = static_cast(test_data[i]); - double r = static_cast(ref_data[i]); + #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread) + { + size_t thread_mismatches = 0; + #pragma omp for schedule(static) + for (size_t i = 0; i < N; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = mismatch && (data_type == DType::kFloat32); - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); - } - if (assertion && i < first_mismatch_idx) { - first_mismatch_idx = i; - is_mismatch_found = true; + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + if (i < first_mismatch_idx) { + first_mismatch_idx = i; + } + thread_mismatches++; + } } + mismatches += thread_mismatches; } return first_mismatch_idx; } void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches = 0; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? 
test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); - const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol); - if (i != N) { + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches); + if ((i != N) && (mismatches > tolerable_mismatches_limit)) { const double t = static_cast(test_data[i]); const double r = static_cast(ref_data[i]); std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(true) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + + GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) + << " (" << std::to_string(i) << "): " << t << " vs " << r; } ); } void compareResults(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { constexpr bool sequential = false; if constexpr (sequential) { - compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } else { - compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } } @@ -656,28 +732,94 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } } -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride) +template +struct CastToType; + +template <> +struct CastToType { + using type = int; +}; + +template <> +struct CastToType { + using type = float; +}; + +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + using UpcastType = typename CastToType::type; + auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); + + const size_t N = row_blocks * col_blocks; + const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, + std::ceil(N * rel_tolerable_mismatches_limit)); + mismatches_num = 0; + std::vector mismatch_indices; + for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[idx]) << " vs " - << static_cast(ref[idx]) << " at index " << idx; + float t, r; + + bool assertion = false; + + if (std::is_same::value) { + t = static_cast(test[idx]); + r = static_cast(ref[idx]); + assertion = std::abs(t - r) > atol; + } else { + t = static_cast(*reinterpret_cast(&test[idx])); + r = static_cast(*reinterpret_cast(&ref[idx])); + const bool mismatch = (fabs(t - r) > atol_fp8e4m3) + && (r == 0 || fabs((t - 
r) / r) > rtol_fp8e4m3); + if (mismatch) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + } + if (assertion) { + mismatches_num++; + mismatch_indices.push_back(idx); + } + if (mismatches_num > tolerable_mismatches_limit) { + std::cout << "Error in " << name << std::endl; + for (const int index : mismatch_indices) { + std::cout << "Mismatch at (" << index << "):" + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; + } + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "."; + } } } } -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t N) -{ - for (int i = 0; i < N; i++) { - ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[i]) << " vs " - << static_cast(ref[i]) << " at index " << i; - } -} +// Instantiate templates +template +void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + +template +void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); std::pair getTolerances(const DType type) { switch(type) { @@ -825,6 +967,10 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } +bool isFp4Type(DType type) { + return type == DType::kFloat4E2M1; +} + int32_t getDeviceComputeCapability() { cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, 0); @@ -846,7 +992,8 @@ std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols) { - const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool is_rowwise = (block_size_rows == 1) + && ((block_size_cols == 32) || (block_size_cols == 16)); const size_t alignment_Y = is_rowwise ? 
scale_tensor_alignment_Y_rowwise diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 3597c94d85..e778070af3 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -62,6 +62,8 @@ using fp8e5m2 = __nv_fp8_e5m2; using fp8e8m0 = uint8_t; #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif template @@ -207,7 +209,9 @@ class Tensor { template T *columnwise_cpu_dptr() const { - NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); + if (tensor_.scaling_mode() != NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); + } NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); return reinterpret_cast(cpu_data_columnwise_.get()); } @@ -223,7 +227,9 @@ class Tensor { float scale() const { if(scale_cpu_data_) { - NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); + NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING), + "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -237,6 +243,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + // NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -250,6 +258,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + // NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E8M0, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -304,10 +314,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement -constexpr size_t scale_tensor_alignment_X_rowwise = 4; -constexpr size_t scale_tensor_alignment_Y_rowwise = 128; -constexpr size_t scale_tensor_alignment_X_colwise = 128; -constexpr size_t scale_tensor_alignment_Y_colwise = 4; +constexpr size_t scale_tensor_alignment_X_rowwise = 128; +constexpr size_t scale_tensor_alignment_Y_rowwise = 4; +constexpr size_t scale_tensor_alignment_X_colwise = 4; +constexpr size_t scale_tensor_alignment_Y_colwise = 128; inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; @@ -413,7 +423,12 @@ inline fp8e8m0 float_to_e8m0(float val) { } inline float exp2f_rcp(fp8e8m0 biased_exp) { - return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); + if (biased_exp == 0) { + return 1.0f; + } + int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) + float fp32_val = *reinterpret_cast(&int_val); + return fp32_val; } inline float identity(const float x) { return x; } @@ -445,15 +460,19 @@ size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); void compareResults(const std::string &name, const Tensor &test, const void *ref, - bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); + bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, + const size_t tolerable_mismatches_limit = 0); void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t N); +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols); @@ -476,6 +495,7 @@ const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); +bool isFp4Type(DType type); int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; @@ -553,7 +573,7 @@ constexpr int32_t blackwellComputeCapability = 100; SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ printf("dtype: %d\n", static_cast(dtype)); \ - NVTE_ERROR("Invalid type MARKED TEST."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ @@ -572,7 +592,7 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 2."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ @@ -580,7 +600,7 @@ constexpr int32_t blackwellComputeCapability = 100; using namespace transformer_engine; \ SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 3."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) 
\ @@ -605,5 +625,5 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 4."); \ + NVTE_ERROR("Invalid type."); \ } diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 4a39328623..cb15c52aca 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -7,9 +7,9 @@ cmake_minimum_required(VERSION 3.21) # Language options if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + set(CMAKE_CUDA_ARCHITECTURES 100a) elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + set(CMAKE_CUDA_ARCHITECTURES 100a) else () set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) endif() @@ -189,11 +189,14 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) set_source_files_properties(activation/gelu.cu activation/relu.cu activation/swiglu.cu + util/cast.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --generate-line-info") # Add source code mapping into the profiler output +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage # Number of parallel build jobs if(ENV{MAX_JOBS}) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 67f173a4ab..6def14559e 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,7 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + quantize_helper(input, grad, output, dbias, workspace, nullptr, stream); } template @@ -46,8 +45,7 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + quantize_helper(input, grad, output, dbias, workspace, nullptr, stream); } template diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 192c915a84..d082f9bf4b 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -146,6 +146,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, }(); // rank is the number of dimensions of the array constexpr uint32_t rank = 2; + + // Dimension for the packed data types must reflect the number of individual U# values. 
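For orientation, the packed-type bookkeeping that the dimension comment above and the alignment check below rely on can be sketched as a standalone snippet. The 16-byte constant mirrors TMA_GMEM_ALIGNMENT from common.h; the helper names are illustrative only and do not exist in the codebase:

    #include <cstddef>
    #include <cstdio>

    // Mirrors TMA_GMEM_ALIGNMENT (16-byte global-memory alignment required by TMA).
    constexpr std::size_t kTmaGmemAlignment = 16;

    // Minimum number of elements per row so that a row spans whole 16-byte units:
    // 32 elements for 4-bit FP4, 16 elements for 8-bit FP8.
    constexpr std::size_t min_row_elems(std::size_t type_num_bits) {
      return (kTmaGmemAlignment * 8) / type_num_bits;
    }

    // Byte pitch of a packed row of global_x elements of type_num_bits bits each.
    constexpr std::size_t row_pitch_bytes(std::size_t global_x, std::size_t type_num_bits) {
      return (global_x * type_num_bits) / 8;
    }

    int main() {
      static_assert(min_row_elems(4) == 32, "4-bit rows: multiple of 32 values");
      static_assert(min_row_elems(8) == 16, "8-bit rows: multiple of 16 values");
      std::printf("128 FP4 values per row -> %zu-byte pitch\n", row_pitch_bytes(128, 4));
      return 0;
    }

In other words, the TMA descriptor's innermost dimension stays in individual E2M1 values (two per byte), while strides and the 16-byte alignment check operate on bytes.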
uint64_t size[rank] = {globalX, globalY}; // The stride is the number of bytes to traverse from the first element of one row to the next @@ -162,10 +164,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, void *dataPtr = reinterpret_cast(reinterpret_cast(tensor.dptr) + (offset_elems * type_num_bits) / 8); - NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), + NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT), "Tensor data pointer must be 16B aligned"); - const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; + const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits; NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits, "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 22b448a001..08001671dc 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -668,7 +668,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; // Alignment requirements for the Tensor Memory Accelerator (TMA) -constexpr int TMA_gmem_alignment = 16; // global memory address alignment +constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment inline bool is_aligned_ptr(const void *ptr, size_t alignment) { return reinterpret_cast(ptr) % alignment == 0; diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index e2d9ecc519..58382564cc 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -27,14 +27,8 @@ namespace transformer_engine { -template -__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { - return DIVUP(static_cast(N), static_cast(M)) * M; -} - namespace gated_kernels { -constexpr size_t ALIGNMENT_SIZE = 128; constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; constexpr size_t THREADS_PER_CHUNK = 512; @@ -76,18 +70,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float amax = 0; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + extern __shared__ __align__(TMA_SHMEM_ALIGNMENT) char dshmem[]; constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); constexpr size_t grad_mem = IS_DGATED ? 
buff_size_aligned_in : 0; @@ -96,8 +86,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t out_act_mem = buff_size_aligned_out; - - // const size_t in_transaction_size = grad_mem + in_mem; constexpr size_t in_transaction_size = buff_elems * sizeof(IType); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned @@ -269,9 +257,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +namespace mxfp8_kernel { + +constexpr size_t CHUNK_DIM_Y = 64; +constexpr size_t CHUNK_DIM_X = 64; +constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; +constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = CHUNK_DIM_X; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + template + bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, const __grid_constant__ CUtensorMap tensor_map_input_act, @@ -284,43 +297,68 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING; + constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); - const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension. 
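The constants introduced in the mxfp8_kernel namespace above, together with COLWISE_WAVEFRONT_SIZE defined just below, fix the loop trip counts used later in the kernel. A minimal standalone check of that arithmetic, assuming the column-wise-only instantiation runs with THREADS_PER_CHUNK_COLWISE = 128; the k-prefixed names simply mirror the constants above:

    #include <cstddef>

    constexpr std::size_t kChunkDimY = 64, kChunkDimX = 64;  // CHUNK_DIM_Y / CHUNK_DIM_X
    constexpr std::size_t kBuffDimY = 32;                    // BUFF_DIM_Y
    constexpr std::size_t kScaleDimY = 32, kScaleDimX = 32;  // SCALE_DIM_Y / SCALE_DIM_X
    constexpr std::size_t kPackSize = 4;                     // PACK_SIZE
    constexpr std::size_t kThreadsColwise = 128;             // THREADS_PER_CHUNK_COLWISE

    constexpr std::size_t divup(std::size_t n, std::size_t m) { return (n + m - 1) / m; }

    // One column-wise wavefront covers DIVUP(threads, CHUNK_DIM_X) rows of the buffer.
    constexpr std::size_t kWavefront = divup(kThreadsColwise, kChunkDimX);
    static_assert(kWavefront == 2, "two rows per column-wise wavefront");
    static_assert(kScaleDimY / kWavefront == 16, "per-buffer loop iterations per thread");
    static_assert(kChunkDimY / kBuffDimY == 2, "STAGES per 64x64 chunk");

    // Row-wise side: each 32-value scale block is processed in WAVES packs of 4 values,
    // and 4 threads span the 128-byte width of all 32 shared-memory banks.
    static_assert(kScaleDimX / kPackSize == 8, "WAVES");
    static_assert(128 / kScaleDimX == 4, "THREADS_PER_BANK");

    int main() { return 0; }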
+ constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; + const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; + const int gate_scale_idx_offset_colwise = cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; - const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? 
COLWISE_WAVEFRONT_SIZE - 1 : 1; + __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X]; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + extern __shared__ __align__(TMA_SHMEM_ALIGNMENT) char dshmem[]; - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); @@ -329,12 +367,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t in_mem = in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); const size_t out_mem = out_act_mem + out_gate_mem; - // const size_t in_transaction_size = grad_mem + in_mem; - const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_grad_sh = reinterpret_cast(dshmem); IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); @@ -346,374 +381,493 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_act_colwise_sh = out_act_rowwise_sh; OType *out_gate_colwise_sh = out_gate_rowwise_sh; - if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + if constexpr (ROWWISE_SCALING && COLWISE_SCALING) { out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); out_gate_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); } - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act_rowwise = - reinterpret_cast(&tensor_map_output_act_rowwise); - const uint64_t *TMAP_output_gate_rowwise = - reinterpret_cast(&tensor_map_output_gate_rowwise); - const uint64_t *TMAP_output_act_colwise = - reinterpret_cast(&tensor_map_output_act_colwise); - const uint64_t *TMAP_output_gate_colwise = - reinterpret_cast(&tensor_map_output_gate_colwise); + IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations + IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + const bool is_master_thread = (threadIdx.x == 0); // Initialize shared memory barrier with the number of threads participating in the barrier. 
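The initialize_barriers(...) call introduced below replaces the explicit set-up sequence deleted in this hunk. For readability, the pattern being wrapped, reconstructed from the removed lines only (the helper's actual definition is not part of this diff and may differ), is:

    // One mbarrier per pipeline stage; every thread of the block participates.
    // if (is_master_thread) {
    //   #pragma unroll
    //   for (int stage = 0; stage < STAGES; ++stage) {
    //     ptx::mbarrier_init(&mbar[stage], THREADS_PER_CHUNK);
    //   }
    //   ptx::fence_proxy_async_shared_cta();  // expose the barriers to the async proxy
    // }
    // __syncthreads();                        // initialization visible to all threads before use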
#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + __shared__ alignas(8) uint64_t mbar[STAGES]; - const bool is_master_thread = (threadIdx.x == 0); - - if (is_master_thread) { -// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); - } - ptx::fence_proxy_async_shared_cta(); - } - // Syncthreads so initialized barrier is visible to all threads. - __syncthreads(); + initialize_barriers(mbar, is_master_thread); int parity = 0; - // Prefetch data of the first stage - if (is_master_thread) { - // Initiate bulk tensor copy - // Grad - if constexpr (IS_DGATED) { - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), - TMAP_grad_in, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - } - - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), - TMAP_in_act, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, + &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[0]); + copy_2d_to_sharedx2(&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } #pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const int buff = it % BUFFERS_NUM; - const int next_it = it + 1; - const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; - if (next_it < ITERATIONS) { - if (is_master_thread) { - const int next_buff = next_it % BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - // Initiate bulk tensor copy - if constexpr (IS_DGATED) { - // Grad - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - } - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, + global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, + global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset], + &tensor_map_input_gate, global_offset_X, global_offset_Y, + shmem_buff_size, &mbar[next_stage], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[next_it]); + copy_2d_to_sharedx2(&in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X, + global_offset_Y, &in_gate_sh[next_buff_offset], &tensor_map_input_gate, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } } ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); - - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; - OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; - OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; - OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; - - // Assuming one iteration covers exactly 32 rows - const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; - const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; + ptx::mbarrier_wait_parity(&mbar[stage], parity); - float after_dact_reg[BUFFER_STAGES_NUM]; - float after_dgate_reg[BUFFER_STAGES_NUM]; - float thread_Y_mx_block_amax = 0.0f; - float thread_Y_mx_block_amax_gate = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = + buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise; + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + float after_act_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; + float after_gate_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; +// 1. Read/Compute elements. 
Find MXFP8-block AMAX #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const int shmem_offset_colwise = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + float act_elt = static_cast(in_act_sh[shmem_offset_colwise]); + float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); + float after_act_elt; + float after_gate_elt; - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + after_act_elt = ActOP(act_elt, {}) * gate_elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } } - after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; - } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; - } - if constexpr (USE_ROWWISE_SCALING) { + after_act_colwise[i] = after_act_elt; if constexpr (IS_DGATED) { - // dgate - float amax = fabsf(after_dgate_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); - - out_gate_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dgate_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; - } + after_gate_colwise[i] = after_gate_elt; } - float amax = fabsf(after_dact_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = 
exp2f_rcp(biased_exponent_X); - - out_act_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dact_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); + if constexpr (IS_DGATED) { + cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); + } } - } - if constexpr (USE_COLWISE_SCALING) { - __builtin_assume(thread_Y_mx_block_amax >= 0); - __builtin_assume(thread_Y_mx_block_amax_gate >= 0); - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); - if constexpr (IS_DGATED) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } } } - } - - if constexpr (USE_COLWISE_SCALING) { - const bool row_out_of_bounds = (row_base >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - if constexpr (IS_DGATED) { - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + if constexpr (ONLY_COLWISE_SCALING) { + // Threads, whose id along Y-dim is 0, don't need to store to shared memory, + // as they manage the columwise reduction of the amax + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act; } __syncthreads(); - if (tid_Y == 0) { + if (tid_Y_colwise == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_act >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_act = fmaxf(thread_amax_act, other_thread_amax); } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act; } __syncthreads(); - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + // All threads read the reduced amax (ACT) + thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + if constexpr (IS_DGATED) { + // Make sure the previous read of the ACT values has been completed, + // so the data are not rewritten + __syncthreads(); + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate; + } + 
__syncthreads(); + if (tid_Y_colwise == 0) { +#pragma unroll + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_gate >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax); + } + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); + + // All threads read the reduced amax (GATE) + thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise]; } + } - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; + const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; + + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx] = biased_exponent_act; + } + + float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + float block_scale_inverse_gate; + + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const int scale_idx_gate = scale_idx + scale_stride_colwise / 2; + const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx_gate] = biased_exponent_gate; } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + } +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_gate_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dgate_reg[stage]); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const int shmem_offset_elt = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; + if constexpr (IS_DGATED) { + OType2 out_pair; + ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; + const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, + block_scale_inverse_gate}; + ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair); + out_act_colwise_sh[shmem_offset_elt] = out_pair.x; + out_gate_colwise_sh[shmem_offset_elt] = out_pair.y; + } else { + const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i]; + out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); } } - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; - } - __syncthreads(); - if (tid_Y == 0) { + } + + if constexpr (ROWWISE_SCALING) { + const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + + Vec in_cached_act[WAVES]; + Vec in_cached_gate[WAVES]; + + float after_act_rowwise[SCALE_DIM_X]; + float after_gate_rowwise[SCALE_DIM_X]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x_act = {static_cast(0.0f), static_cast(0.0f)}; + IType2 thread_amax_2x_gate = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); + } + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e])); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e])); + } + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e], + in_cached_act[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act); + if constexpr (IS_DGATED) { + const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e], + in_cached_gate[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate); + } + } + } + } } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax - } - __syncthreads(); + if constexpr (!std::is_same_v) { + thread_amax_act = static_cast( + __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); + if constexpr (IS_DGATED) { + thread_amax_gate = static_cast( + __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); + } + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in_grad; + Vec in_act; + Vec in_gate; - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + in_act.load_from(&in_act_sh[shmem_offset_rowwise]); + in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); + } + +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + + float act_elt = static_cast(in_act.data.elt[e]); + float gate_elt = static_cast(in_gate.data.elt[e]); + float after_act_elt; + float after_gate_elt; + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad.data.elt[e]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; + after_act_rowwise[j] = after_act_elt; + after_gate_rowwise[j] = after_gate_elt; + } else { + after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_rowwise[j] = after_act_elt; + } + + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } + } + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if 
constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } + } + } + } + } - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; + const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx] = biased_exponent_act; } - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act, + block_scale_inverse_act}; + + float block_scale_inverse_gate; + ptx::floatx2 block_scale_inverse_2x_gate; + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate}; } +// 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out_act; + Vec out_gate; #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_act_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dact_reg[stage]); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in_act; + OType2 &out_act_pair = reinterpret_cast(out_act.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_act.x = in_cached_act[w].data.elt[2 * e]; + in_act.y = in_cached_act[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_act.x = after_act_rowwise[j]; + in_act.y = after_act_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act); + + if constexpr (IS_DGATED) { + IType2 in_gate; + OType2 &out_gate_pair = reinterpret_cast(out_gate.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_gate.x = in_cached_gate[w].data.elt[2 * e]; + in_gate.y = in_cached_gate[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_gate.x = after_gate_rowwise[j]; + in_gate.y = after_gate_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); + } + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); + } } - } // endif USE_COLWISE_SCALING + } - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); // After syncthreads, writes by all threads are visible to TMA engine. 
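+    // Ordering note: fence_proxy_async_shared_cta() makes the generic-proxy shared-memory
+    // writes visible to the async (TMA) proxy, and the __syncthreads() above guarantees that
+    // every thread has finished writing its output tile before the single master thread
+    // issues the bulk shared-to-global copies below; cp_async_bulk_commit_group() then turns
+    // those copies into a bulk async-group that can be waited on before the shared buffer
+    // is reused.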
// Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; - // dGeLU - if constexpr (USE_ROWWISE_SCALING) { + if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_rowwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_rowwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_rowwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_rowwise_sh[buff_offset])); } } - - // dGeLU - if constexpr (USE_COLWISE_SCALING) { + if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_colwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_colwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_colwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_colwise_sh[buff_offset])); } } // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); } } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. - if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } + parity ^= 1; + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace mxfp8_kernel template @@ -771,17 +925,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem); // + mbar_mem; + const size_t shmem_size = + grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem); cudaFuncSetAttribute( cast_fp8_gated_kernel, @@ -809,16 +962,34 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } - // TODO: Make more general - const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; - const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; + ScalingType scaling_type; + if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { + scaling_type = ScalingType::COLWISE; + } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + scaling_type = ScalingType::BIDIMENSIONAL; + } const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; + constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; + constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + + const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + + constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; + constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; + const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) + ? THREADS_PER_CHUNK_COLWISE + : THREADS_PER_CHUNK_NON_COLWISE; + + const dim3 grid(blocks_X, blocks_Y); + const dim3 block_size(THREADS_PER_CHUNK); size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; @@ -828,94 +999,122 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out e8m0_t *const scales_colwise_ptr = USE_COLWISE_SCALING ? 
reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + } + + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, 0, input_type_bit_size); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, cols, input_type_bit_size); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } + + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, + cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_act_colwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, - typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, - typeToNumBits(gated_input.dtype())); - - if (USE_ROWWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, - typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, 
- typeToNumBits(output->dtype())); - } - - if (USE_COLWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - cols, typeToNumBits(output->dtype())); - } - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - size_t out_mem = out_act_mem + out_gate_mem; - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } - - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; - - const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; - - cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); - - cast_mxfp8_gated_kernel - <<>>( + const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; + const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? 
buff_size_aligned_out : 0); + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + const size_t shmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise);); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + scale_stride_colwise); + break; + case ScalingType::COLWISE: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + }); // NOLINT(*) + ); // NOLINT(*) } template @@ -1064,21 +1263,21 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); Tensor *output_tensor = convertNVTETensorCheck(output); - if (is_supported_by_CC_100()) { - quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, stream); - } else { - if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { - if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); - } else { - cast_gated(gated_input_tensor, output_tensor, stream); - } - } else { - // MX scaling - NVTE_ERROR("Not supported by the Arch < 10.0"); - } - } + // if (is_supported_by_CC_100()) { + // quantize_gated(grad_tensor, gated_input_tensor, + // output_tensor, stream); + // } else { + // if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { + // if constexpr (IS_DGATED) { + // cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + // } else { + // cast_gated(gated_input_tensor, output_tensor, stream); + // } + // } else { + // // MX scaling + // NVTE_ERROR("Not supported by the Arch < 10.0"); + // } + // } } } // namespace detail diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 610cbf41fa..e3170a95b5 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -28,36 +28,25 @@ namespace transformer_engine { -constexpr size_t MXFP8_CHUNK_DIM_Y = 64; -constexpr size_t MXFP8_CHUNK_DIM_X = 64; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; 
-constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; -constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; -constexpr size_t MXFP8_BUFFERS_NUM = 2; -constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); - -constexpr size_t ELEMS_PER_THREAD = 16; -constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported -constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 -constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 - -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = - MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 -constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = - MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_BUFF_STAGES_NUM = - MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 -constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 -static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); +namespace mxfp8_kernel { + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 template -__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING, + bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK> +__global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_output_rowwise, @@ -67,201 +56,336 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - if (noop != nullptr && noop[0] == 1.0f) return; - } + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = - SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = - SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = - SCALES_COLWISE_PER_CHUNK_Y * 
MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_X = - SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 - - const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; - const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; - const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; - const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; - const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; - - const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; - const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; - // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; - const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - - const int thread_offset_Y = tid_rowwise_Y; - const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - // const int thread_offset_X_colwise = tid_colwise_X; - - const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; - const int dbias_rowwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; - const int dbias_colwise_offset_Y = blockIdx.y; - const int dbias_colwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; - const int dbias_stride = cols; + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; - Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; - float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; - if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_rowwise[i].clear(); - } - } else { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_colwise[i] = 0; - } + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; } } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X; + const int tid_X_rowwise = threadIdx.x % THREADS_X; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int 
thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; - // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + extern __shared__ __align__(TMA_SHMEM_ALIGNMENT) char dshmem[]; - constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; const bool is_master_thread = (threadIdx.x == 0); - float block_amax = 0; + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; // Initialize shared memory barrier with the number of threads participating in the barrier. 
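+  // One mbarrier per stage: the prologue below issues the TMA load for stage 0, and inside the
+  // stage loop the next stage's tile is prefetched into the other half of the double-buffered
+  // (BUFFS_NUM = 2) shared-memory input buffer while the current tile is processed; each stage
+  // waits on its own barrier via mbarrier_wait_parity() before reading the data.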
#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + __shared__ alignas(8) uint64_t mbar[STAGES]; - initialize_barriers(mbar, is_master_thread); + initialize_barriers(mbar, is_master_thread); int parity = 0; -#pragma unroll - for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { - const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; - const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; - const int scales_rowwise_chunk_offset_Y = - scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = - scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = - scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = - scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; - const int chunk_stage_offset_X = chunk_offset_X; + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); } } + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); #pragma unroll - for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { - const int buff = iter % MXFP8_BUFFERS_NUM; - const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - - if (next_iter < MXFP8_ITERATIONS) { - const int next_buff = next_iter % MXFP8_BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, - &mbar[next_iter], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); } - } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - ptx::fence_proxy_async_shared_cta(); + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec act_in; - Vec out_c; + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; - const int iteration_scale_rowwise_offset_Y = - scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X_rowwise; + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } - in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - } + if constexpr (ROWWISE_SCALING) { + const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; - float thread_amax = 0; - float in_compute[ELEMS_PER_THREAD]; + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - float elt = static_cast(in.data.elt[j]); + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); if constexpr (IS_ACT) { elt = OP(elt, {}); } if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[j]); + float act_in_elt = static_cast(act_in.data.elt[e]); elt *= OP(act_in_elt, {}); } - if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - if (!out_of_bounds) { - partial_dbias_rowwise[chunk_X].data.elt[j] += elt; - } - } - in_compute[j] = elt; - if constexpr (IS_ACT || IS_DACT) { + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); } @@ -269,196 +393,616 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) // If no activation, elt is 0 so we can safely do this thread_amax = fmaxf(thread_amax, fabsf(elt)); } + in_compute_rowwise[j] = elt; } + } + } - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); - const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - - // Only single thread writes the computed scaling factor - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + 
global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; - } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + // 3. Scale elements #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; } - out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } + } - if constexpr (USE_COLWISE_SCALING) { - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - float in_compute[SCALE_DIM_Y]; + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - float amax = 0; + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. 
+ ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace mxfp8_kernel + +namespace nvfp4_kernel { + +using namespace ptx; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + + +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 +compute_decoding_scaling_factor(const float block_amax, const float S_enc) { + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, 
float *const amax_ptr, const float *const nvfp4_second_stage_scale_ptr, + const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && "Number of buffer rows must be greater than or equal to the size of the columnwise scaling block"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && "Number of buffer rows must be greater than or equal to the number of rowwise processing threads in the Y dimension"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Each byte holds two 4-bit elements + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // Helps resolve shared memory bank conflicts + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + extern __shared__ __align__(TMA_SHMEM_ALIGNMENT) char dshmem[]; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * 
sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = (nvfp4_second_stage_scale_ptr == nullptr) + ? 1.0f + : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + + // Initialize shared memory barrier with the number of threads participating in the barrier. + #pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + + #pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); + #pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { + #pragma unroll for (int i = 0; i < SCALE_DIM_Y; ++i) { - const size_t row = row_base + i; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - float elt = static_cast(in_sh[buff][i][tid_colwise_X]); - if constexpr (IS_ACT) { + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { elt = OP(elt, {}); } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); - elt *= OP(act_in_elt, {}); + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); } - if constexpr (IS_DBIAS) { - if (!out_of_bounds) { - partial_dbias_colwise[chunk_X] += elt; - } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); } - in_compute[i] = elt; - if constexpr (IS_ACT || IS_DACT) { + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); + block_amax = fmaxf(block_amax, fabsf(elt)); } } else { // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); + block_amax = fmaxf(block_amax, fabsf(elt)); } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(amax >= 0); - block_amax = fmaxf(block_amax, amax); - - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); - - const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = exp2f_rcp(biased_exponent); -#pragma unroll 
- for (int i = 0; i < SCALE_DIM_Y; ++i) { - out_colwise_sh[buff][i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); + in_compute_colwise[i] = elt; } } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (USE_ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); - } - if constexpr (USE_COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (global_scales_offset_X < cols) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + + // 3. Scale elements + #pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; } - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); + const float scaled_out = in * block_scale_inverse; - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); } } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - parity ^= 1; - } - - if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; - constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; - constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; - if (tid_rowwise_Y > 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - partial_dbias_rowwise[c].store_to( - &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; + #pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. 
Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + #pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { + #pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { + #pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + #pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } 
+ } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + + const int shmem_scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + + // 3. Scale elements + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; + #pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } } - __syncthreads(); - - if (tid_rowwise_Y == 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - Vec other_row_dbias; - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + } - const int left_bound = dbias_rowwise_offset_X; - const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); -#pragma unroll - for (int i = 0; i < Y; ++i) { - other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; - } - } + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
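+ // Summary of the rowwise NVFP4 scaling computed above (illustrative comment only;
+ // 6.0f is the largest magnitude representable in E2M1):
+ //   S_dec_b   = block_amax / 6.0f              // per-block decoding factor
+ //   stored    = fp8e4m3(S_dec_b * S_enc)       // written to out_rowwise_scales_sh / scales_rowwise_e4m3
+ //   quantized = fp4e2m1(x * S_enc / stored)    // elements written to out_rowwise_data_sh
+ // so x is recovered as roughly fp4_value * stored / S_enc, where S_enc is
+ // 1 / (*nvfp4_second_stage_scale_ptr), or 1.0f when that pointer is null.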
- // Vectorized store when all elements are inside the boundaries - if (right_bound < cols) { - partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); - } else if (left_bound < cols && right_bound >= cols) { - // Element-by-element store when some elements cross the boundaries - const int in_bound_elts_count = cols - left_bound; - partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, - in_bound_elts_count); - } - } + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); } - } else { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; - } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. + if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t& scales = *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); } } + float chunk_amax = 0.0f; if (amax_ptr != nullptr) { const int warp_id = threadIdx.x / THREADS_PER_WARP; // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); + chunk_amax = reduce_max(thread_amax, warp_id); } if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); + atomicMaxFloat(amax_ptr, chunk_amax); } - destroy_barriers(mbar, is_master_thread); + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace nvfp4_kernel + constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_X = 128; @@ -507,9 +1051,12 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; @@ -678,8 +1225,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; @@ -921,6 +1468,7 @@ template has_data(); bool use_colwise_scaling = output->has_columnwise_data(); checkCuDriverContext(stream); @@ -936,16 +1484,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, } CheckNoopTensor(*noop, "cast_noop"); - // TODO: Make more general - const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; - const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; - const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); - const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); - const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; + + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; const size_t scale_stride_rowwise = use_rowwise_scaling ? 
output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = @@ -958,6 +1514,15 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); @@ -972,58 +1537,231 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); - const dim3 block(MXFP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype(), OType, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(input.dtype())); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, - cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } - - cast_mxfp8_2D_kernel<<>>( + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + 
constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_data_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_data_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::COLWISE: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. 
r16c32 - Rowwise NVFP4 AND Colwise MXFP8 +template +void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { + using namespace nvfp4_kernel; + using namespace ptx; + checkCuDriverContext(stream); + + NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + bool use_colwise_scaling = output->has_columnwise_data(); + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = output->scale_inv.shape[1]; + const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_e8m0_ptr = use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const ScalingType scaling_type = use_colwise_scaling + ? ScalingType::BIDIMENSIONAL + : ScalingType::ROWWISE; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const nvfp4_second_stage_scale_ptr = reinterpret_cast(output->scale.dptr); + + // Output data type is only required for the column-wise MXFP8 scaling. + // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work + const DType output_data_type = use_colwise_scaling + ? 
output->columnwise_data.dtype + : DType::kFloat8E4M3; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output_data_type, OType, + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); + } + + constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_nvfp4_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; + const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; + + const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; + const size_t out_colwise_scales_mem = use_colwise_scaling ? 
buff_size_mxfp8_scales : 0; + + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + + out_rowwise_scales_mem + out_colwise_scales_mem; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, + noop_ptr, amax_ptr, nvfp4_second_stage_scale_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, + noop_ptr, amax_ptr, nvfp4_second_stage_scale_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + break; + } + ); // NOLINT(*) + ); // NOLINT(*) } namespace detail { @@ -1117,8 +1855,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons case NVTE_DELAYED_TENSOR_SCALING: { if (!IS_DBIAS && !IS_DACT) { if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 cast_fp8_1D(input, output, stream); } else { @@ -1127,9 +1865,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons } } else if (!IS_DBIAS && IS_DACT) { if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, TMA_gmem_alignment) && - is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 (+dAct) cast_fp8_2D(input, act_input, output, dbias, workspace, stream); @@ -1186,29 +1924,29 @@ void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *no CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias != nullptr); - CheckOutputTensor(*dbias, "dbias"); - } - if constexpr (IS_DACT) { - NVTE_CHECK(act_input != nullptr); - CheckInputTensor(*act_input, "activation_input"); - NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); - } - - NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { - fp8_quantize_arch_ge_100(input, act_input, noop, output, - dbias, workspace, stream); - } else { - // Supported by the Arch < 10.0 - fp8_quantize_arch_l_100(input, act_input, noop, output, - dbias, workspace, stream); - } + // if constexpr (IS_DBIAS) { + // NVTE_CHECK(dbias != nullptr); + // CheckOutputTensor(*dbias, "dbias"); + // } + // if constexpr (IS_DACT) { + // NVTE_CHECK(act_input != nullptr); + // 
CheckInputTensor(*act_input, "activation_input"); + // NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + // NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + // } + + // NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + // NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // // Supported by the Arch >= 10.0 + // if (is_supported_by_CC_100()) { + // fp8_quantize_arch_ge_100(input, act_input, noop, output, + // dbias, workspace, stream); + // } else { + // // Supported by the Arch < 10.0 + // fp8_quantize_arch_l_100(input, act_input, noop, output, + // dbias, workspace, stream); + // } } namespace detail { @@ -1241,73 +1979,77 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - } else if (output_tensor->has_data()) { - fp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - bool rowwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT - : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = quant_config_cpp - ? 
quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - columnwise_option = columnwise_compact - ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT - : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, output_tensor->data, - output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, stream); + // case NVTE_DELAYED_TENSOR_SCALING: { + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + // cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // } + // } else if (output_tensor->has_data()) { + // fp8_quantize( + // *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + // workspace_tensor, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8_quantize( + // *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + // workspace_tensor, stream); + // break; + // } + case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: { + nvfp4_quantize(*input_tensor, &noop_tensor, output_tensor, stream); break; } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + // "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; + // float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + // quantize_transpose_square_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + // "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; + // float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // bool rowwise_compact = quant_config_cpp + // ? quant_config_cpp->float8_block_scale_tensor_format == + // Float8BlockScaleTensorFormat::COMPACT + // : false; + // rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + // : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // bool columnwise_compact = quant_config_cpp + // ? quant_config_cpp->float8_block_scale_tensor_format == + // Float8BlockScaleTensorFormat::COMPACT + // : false; + // columnwise_option = columnwise_compact + // ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + // : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, + // output_tensor->columnwise_scale_inv, output_tensor->data, + // output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, stream); + // break; + // } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e716065abd..ae8df99916 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -84,8 +84,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; constexpr int transaction_size = shmem_buff_size; @@ -166,7 +166,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float block_scale = ptx::exp2f(biased_exponent); if constexpr (USE_ROWWISE_SCALING) { Vec in; @@ -226,28 +226,28 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - - constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), nullptr, - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) + // const size_t N = product(input.data.shape); + // TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + // input.data.dtype, IType, + // TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + // output->data.dtype, OType, + + // constexpr int nvec = 32 / sizeof(OType); + // detail::DequantizeParam p; + // p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + // VectorizedUnaryKernelLauncher( + // reinterpret_cast(input.data.dptr), nullptr, + // reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + // stream);); // NOLINT(*) + // ); // NOLINT(*) } -static void mxfp8_dequantize(const Tensor &input, Tensor *output, 
cudaStream_t stream) { +void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); checkCuDriverContext(stream); @@ -268,67 +268,67 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - // TODO: Make more general - const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; - const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); - - const size_t unpadded_scales_Y_rowwise = rows; - const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); - const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise); - const size_t unpadded_scales_X_colwise = cols; - - const size_t scales_Y_rowwise = - DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) * - scale_tensor_alignment_Y_rowwise; - const size_t scales_X_rowwise = - DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) * - scale_tensor_alignment_X_rowwise; - const size_t scales_Y_colwise = - DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) * - scale_tensor_alignment_Y_colwise; - const size_t scales_X_colwise = - DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * - scale_tensor_alignment_X_colwise; - - const e8m0_t *const scales_ptr = - use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) - : reinterpret_cast(input.columnwise_scale_inv.dptr); - - const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; - - const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data; - - const dim3 block(THREADS_PER_CHUNK); - const dim3 grid(chunks_X, chunks_Y); - - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype())); - - dequantize_mxfp8_kernel - <<>>(tensor_map_input, tensor_map_output, scales_ptr, - rows, cols, scales_stride);); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + // // TODO: Make more general + // const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + // const size_t scale_dim_Y_colwise = use_colwise_scaling ? 
32 : 1; + + // const size_t rows = input.flat_first_dim(); + // const size_t cols = input.flat_last_dim(); + // const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); + // const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); + + // const size_t unpadded_scales_Y_rowwise = rows; + // const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); + // const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise); + // const size_t unpadded_scales_X_colwise = cols; + + // const size_t scales_Y_rowwise = + // DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) * + // scale_tensor_alignment_Y_rowwise; + // const size_t scales_X_rowwise = + // DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) * + // scale_tensor_alignment_X_rowwise; + // const size_t scales_Y_colwise = + // DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) * + // scale_tensor_alignment_Y_colwise; + // const size_t scales_X_colwise = + // DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * + // scale_tensor_alignment_X_colwise; + + // const e8m0_t *const scales_ptr = + // use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) + // : reinterpret_cast(input.columnwise_scale_inv.dptr); + + // const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; + + // const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data; + + // const dim3 block(THREADS_PER_CHUNK); + // const dim3 grid(chunks_X, chunks_Y); + + // TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + // scale_dim_Y_colwise, SCALE_DIM_Y, + // TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + // scale_dim_X_rowwise, SCALE_DIM_X, + // TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + // input.dtype(), IType, + // TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + // output->dtype(), OType, + + // alignas(64) CUtensorMap tensor_map_input{}; + // alignas(64) CUtensorMap tensor_map_output{}; + + // create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, + // SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); + // create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, + // SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype())); + + // dequantize_mxfp8_kernel + // <<>>(tensor_map_input, tensor_map_output, scales_ptr, + // rows, cols, scales_stride);); // NOLINT(*) + // ); // NOLINT(*) + // ); // NOLINT(*) + // ); // NOLINT(*) } } // namespace dequantization @@ -338,18 +338,18 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); - if (is_tensor_scaling(input.scaling_mode)) { - dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { - if (is_supported_by_CC_100()) { - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - } else { - // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } + // if (is_tensor_scaling(input.scaling_mode)) { + // dequantization::fp8_dequantize(input, output, stream); + // } else if (is_mxfp_scaling(input.scaling_mode)) { + // if (is_supported_by_CC_100()) { + // dequantization::mxfp8_dequantize(input, output, stream); + // } else { + // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + // } + // } 
else { + // // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING + // NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + // } } } // namespace detail diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 55bc247f70..7c65ae96ac 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -104,6 +104,54 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 1 + : __int_as_float((254 - biased_exp) + << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) +} + +__device__ __forceinline__ float exp2f(e8m0_t biased_exp) { + return __int_as_float(biased_exp << FP32_MANTISSA_BITS); +} + +#define CUDA_ARCH_HAS_CVT_FEATURE ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { +#if CUDA_ARCH_HAS_CVT_FEATURE + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -169,6 +217,225 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { asm volatile("fence.proxy.async.shared::cta;"); } +template +struct FPx2 { + T x; + T y; +}; + +template +struct FPx4 { + T x1; + T x2; + T x3; + T x4; +}; + +template +struct Type2x {}; + +template <> +struct Type2x { + using type = float2; +}; + +template <> +struct Type2x { + using type = __nv_bfloat162; +}; + +template <> +struct Type2x { + using type = __half2; +}; + +using floatx2 = FPx2; +using bf16x2 = FPx2; +using fp16x2 = FPx2; +using fp8e4m3x2 = FPx2; +using fp8e5m2x2 = FPx2; + +using floatx4 = FPx4; +using bf16x4 = FPx4; +using fp16x4 = FPx4; +using fp8e4m3x4 = FPx4; +using fp8e5m2x4 = FPx4; + +#include +using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; + +static_assert(sizeof(floatx2) == 8); +static_assert(sizeof(bf16x2) == 4); +static_assert(sizeof(fp16x2) == 4); +static_assert(sizeof(fp8e4m3x2) == 2); +static_assert(sizeof(fp8e5m2x2) == 2); +static_assert(sizeof(fp4e2m1x2) == 1); + +static_assert(sizeof(fp4e2m1x4) == 2); + +// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1 + +// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6. 
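// Illustrative sketch, not part of this patch: a possible fp4 counterpart to the fp8
// mul_cvt_2x overloads below. It scales a float pair and packs it into a single e2m1x2
// byte, assuming the float2 constructor of __nv_fp4x2_e2m1 from <cuda_fp4.h> (aliased as
// fp4e2m1x2 above); on sm_100a/sm_101a/sm_120a this packing is what the
// cvt.rn.satfinite.e2m1x2.f32 instruction described in this comment block performs.
__device__ __forceinline__ void mul_cvt_2x_sketch(fp4e2m1x2 &out, const floatx2 &in,
                                                  const float scale) {
  // Scale both lanes, then convert each FP32 value to e2m1 with saturation; the two
  // 4-bit results are packed into one byte.
  out = fp4e2m1x2(make_float2(in.x * scale, in.y * scale));
}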
+ + // cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on the following architectures: + // sm_100a + // sm_101a + // sm_120a + + // When converting to .e2m1x2 data formats, the destination operand d has .b8 type. + // When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, + // and the converted values are packed in the destination operand d such that the value + // converted from input a is stored in the upper 4 bits of d and the value converted + // from input b is stored in the lower 4 bits of d. + + // SIMD-like fused cast + multiplication (x4) + template + __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23, + const float scale) { + const float x0 = static_cast(in01.x) * scale; + const float x1 = static_cast(in01.y) * scale; + const float x2 = static_cast(in23.x) * scale; + const float x3 = static_cast(in23.y) * scale; + out = fp4e2m1x4(make_float4(x0, x1, x2, x3)); + } + + // SIMD-like fused cast + multiplication (x2) + __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, + const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + "mul.f32x2 val_pair, %1, %2; \n\t" + "mov.b64 {val2,val1}, val_pair; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); + } + + __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, + const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + "mul.f32x2 val_pair, %1, %2; \n\t" + "mov.b64 {val2,val1}, val_pair; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); + } + + __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_bf16; \n\t" + ".reg.b16 val2_bf16; \n\t" + "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t" + "cvt.f32.bf16 val1, val1_bf16; \n\t" + "cvt.f32.bf16 val2, val2_bf16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); + } + + __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_bf16; \n\t" + ".reg.b16 val2_bf16; \n\t" + "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t" + "cvt.f32.bf16 val1, val1_bf16; \n\t" + "cvt.f32.bf16 val2, val2_bf16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); + } + + __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 
val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_fp16; \n\t" + ".reg.b16 val2_fp16; \n\t" + "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t" + "cvt.f32.f16 val1, val1_fp16; \n\t" + "cvt.f32.f16 val2, val2_fp16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_fp16; \n\t" + ".reg.b16 val2_fp16; \n\t" + "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t" + "cvt.f32.f16 val1, val1_fp16; \n\t" + "cvt.f32.f16 val2, val2_fp16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) { + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;" + : "=r"(reinterpret_cast(dst)) + : "r"(reinterpret_cast(p1)), + "r"(reinterpret_cast(p2))); +} + +__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) { + asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;" + : "=r"(reinterpret_cast(dst)) + : "r"(reinterpret_cast(p1)), + "r"(reinterpret_cast(p2))); +} + #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 420b9ed3bb..48d12f8734 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -343,22 +343,22 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out constexpr size_t max_blocks = 65535; num_blocks = std::min(num_blocks, max_blocks); - switch (align) { - case Alignment::SAME_ALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); - break; - case Alignment::SAME_UNALIGNED: - unary_kernel<<>>( - input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); - break; - case Alignment::DIFFERENT: { - // If the pointers are aligned differently we cannot vectorize - unary_kernel<1, true, fp32, Param, OP><<>>( - input, noop, output, scale, amax, scale_inv, params, N, N); - break; - } - } + // switch (align) { + // case Alignment::SAME_ALIGNED: + // unary_kernel<<>>( + // input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + // break; + // case Alignment::SAME_UNALIGNED: + // unary_kernel<<>>( + // input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); + // break; + // case Alignment::DIFFERENT: { + // // If the pointers are aligned differently we cannot vectorize + // unary_kernel<1, true, fp32, Param, OP><<>>( + // input, noop, output, scale, amax, scale_inv, params, N, N); + // break; + // } + // } } } @@ -377,22 +377,22 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp constexpr size_t 
max_blocks = 65535; num_blocks = std::min(num_blocks, max_blocks); - switch (align) { - case Alignment::SAME_ALIGNED: - unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); - break; - case Alignment::SAME_UNALIGNED: - unary_grad_kernel<<>>( - grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); - break; - case Alignment::DIFFERENT: { - // If the pointers are aligned differently we cannot vectorize - unary_grad_kernel<1, true, fp32, Param, OP><<>>( - grad, input, output, scale, amax, scale_inv, params, N, N); - break; - } - } + // switch (align) { + // case Alignment::SAME_ALIGNED: + // unary_grad_kernel<<>>( + // grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + // break; + // case Alignment::SAME_UNALIGNED: + // unary_grad_kernel<<>>( + // grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + // break; + // case Alignment::DIFFERENT: { + // // If the pointers are aligned differently we cannot vectorize + // unary_grad_kernel<1, true, fp32, Param, OP><<>>( + // grad, input, output, scale, amax, scale_inv, params, N, N); + // break; + // } + // } } } @@ -464,24 +464,24 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c constexpr size_t max_blocks = 65535; num_blocks = std::min(num_blocks, max_blocks); - switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { - case Alignment::SAME_ALIGNED: - gated_act_kernel - <<>>(input, output, scale, amax, scale_inv, m, n, p, - num_aligned_elements); - break; - case Alignment::SAME_UNALIGNED: - gated_act_kernel - <<>>(input, output, scale, amax, scale_inv, m, n, p, - num_aligned_elements); - break; - case Alignment::DIFFERENT: { - // If the pointers are aligned differently we cannot vectorize - gated_act_kernel<1, true, ComputeType, Param, Activation> - <<>>(input, output, scale, amax, scale_inv, m, n, p, n); - break; - } - } + // switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) { + // case Alignment::SAME_ALIGNED: + // gated_act_kernel + // <<>>(input, output, scale, amax, scale_inv, m, n, p, + // num_aligned_elements); + // break; + // case Alignment::SAME_UNALIGNED: + // gated_act_kernel + // <<>>(input, output, scale, amax, scale_inv, m, n, p, + // num_aligned_elements); + // break; + // case Alignment::DIFFERENT: { + // // If the pointers are aligned differently we cannot vectorize + // gated_act_kernel<1, true, ComputeType, Param, Activation> + // <<>>(input, output, scale, amax, scale_inv, m, n, p, n); + // break; + // } + // } } } @@ -571,23 +571,23 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu num_blocks = std::min(num_blocks, max_blocks); switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { - case Alignment::SAME_ALIGNED: - dgated_act_kernel - <<>>(grad, input, output, scale, amax, scale_inv, m, n, - p, num_aligned_elements); - break; - case Alignment::SAME_UNALIGNED: - dgated_act_kernel - <<>>(grad, input, output, scale, amax, scale_inv, m, n, - p, num_aligned_elements); - break; - case Alignment::DIFFERENT: { - // If the pointers are aligned differently we cannot vectorize - dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> - <<>>(grad, input, output, scale, amax, scale_inv, m, n, - p, n); - break; - } + // case Alignment::SAME_ALIGNED: + // dgated_act_kernel + // <<>>(grad, input, output, scale, amax, scale_inv, m, n, + // p, 
num_aligned_elements); + // break; + // case Alignment::SAME_UNALIGNED: + // dgated_act_kernel + // <<>>(grad, input, output, scale, amax, scale_inv, m, n, + // p, num_aligned_elements); + // break; + // case Alignment::DIFFERENT: { + // // If the pointers are aligned differently we cannot vectorize + // dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> + // <<>>(grad, input, output, scale, amax, scale_inv, m, n, + // p, n); + // break; + // } } } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index e6a54108ed..3f5bcc975d 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -905,10 +905,7 @@ using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; using e8m0_t = uint8_t; -constexpr uint32_t FP32_MANTISSA_BITS = 23; -constexpr uint32_t FP32_EXPONENT_BIAS = 127; - -enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 }; +enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 }; template struct Numeric_Traits; @@ -934,44 +931,6 @@ struct Quantized_Limits { static constexpr float emax_rcp = 1.0 / emax; }; -__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (isnan(val)) { - return 0xFF; - } - if (isinf(val)) { - return 0xFE; - } -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) - uint16_t out; - asm volatile( - "{\n" - "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" - "}" - : "=h"(out) - : "f"(val)); - return *reinterpret_cast(&out); -#else - if (val == 0.0f) { - return 0x00; - } - uint32_t val_u32 = *reinterpret_cast(&val); - e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); - uint32_t mantissa = val_u32 & 0x7FFFFF; - // Round up exponent and deal with satfinite. - if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { - ++exponent; - } - return exponent; -#endif -} - -__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { - return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); -} - } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
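// Illustrative sketch, not part of this patch: a host-side reference for the e8m0 helpers
// that this change moves from utils.cuh into ptx.cuh. float_to_e8m0_ref mirrors the non-PTX
// fallback path of ptx::float_to_e8m0 (round the FP32 exponent upward, saturating at 0xFE),
// and exp2f_rcp_ref computes 2^(127 - biased_exp), the reciprocal of the decoded
// power-of-two scale. The _ref names and the block_amax/fp8_max variables in the usage note
// are hypothetical.
#include <cmath>
#include <cstdint>
#include <cstring>

inline uint8_t float_to_e8m0_ref(float val) {
  if (std::isnan(val)) return 0xFF;  // NaN maps to the e8m0 NaN encoding
  if (std::isinf(val)) return 0xFE;  // saturate infinities to the largest finite exponent
  if (val == 0.0f) return 0x00;
  uint32_t bits;
  std::memcpy(&bits, &val, sizeof(bits));
  uint8_t exponent = static_cast<uint8_t>(bits >> 23);
  const uint32_t mantissa = bits & 0x7FFFFF;
  // Round the exponent up unless it is already saturated, or the value is a denormal small
  // enough to still round down to 2^-127.
  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
    ++exponent;
  }
  return exponent;
}

inline float exp2f_rcp_ref(uint8_t biased_exp) {
  // 2^(127 - biased_exp); a zero biased exponent is treated as 1.0f, matching the device
  // helper (which builds the value by bit manipulation instead of calling exp2f).
  return (biased_exp == 0) ? 1.0f : std::exp2f(127.0f - static_cast<float>(biased_exp));
}

// Usage note: for a scaling block, the shared e8m0 exponent and the encode-side multiplier
// would be obtained as
//   uint8_t e = float_to_e8m0_ref(block_amax / fp8_max);  // fp8_max: target format's max magnitude
//   float encode_scale = exp2f_rcp_ref(e);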