diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu
index 2696c20498..2e117eb6b2 100644
--- a/transformer_engine/common/multi_tensor/adam.cu
+++ b/transformer_engine/common/multi_tensor/adam.cu
@@ -225,8 +225,8 @@ struct AdamFunctorMasterParamRemainder {
           r_m[ii] = static_cast<MATH_T>(m[i]);
           r_v[ii] = static_cast<MATH_T>(v[i]);
-          local_p[ii] = static_cast<int16_t>(p[i]);
-          local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
+          local_p[ii] = p[i];
+          local_p_rem[ii] = p_remainder[i];
         } else {
           r_g[ii] = MATH_T(0);
           r_m[ii] = MATH_T(0);
@@ -280,8 +280,8 @@ struct AdamFunctorMasterParamRemainder {
       for (int ii = 0; ii < ILP; ii++) {
         int i = i_start + threadIdx.x + ii * blockDim.x;
         if (i < n && i < chunk_size) {
-          p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
-          p[i] = static_cast<int16_t>(local_p[ii]);
+          p_remainder[i] = local_p_rem[ii];
+          p[i] = local_p[ii];
           m[i] = static_cast<MATH_T>(r_m[ii]);
           v[i] = static_cast<MATH_T>(r_v[ii]);
@@ -466,8 +466,8 @@ struct AdamCapturableFunctor {
         int i = i_start + threadIdx.x + ii * blockDim.x;
         if (i < n && i < chunk_size) {
           p[i] = static_cast<T>(r_p[ii]);
-          m[i] = static_cast<T>(r_m[ii]);
-          v[i] = static_cast<T>(r_v[ii]);
+          m[i] = static_cast<float>(r_m[ii]);
+          v[i] = static_cast<float>(r_v[ii]);
         }
       }
     }
@@ -577,9 +577,6 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
                             const float beta1, const float beta2, const float epsilon,
                             const int step, const int mode, const int bias_correction,
                             const float weight_decay, const int device_id, cudaStream_t stream) {
-  const size_t num_tensor_lists = tensor_lists.size();
-  const size_t num_tensors_per_list = tensor_lists[0].size();
-
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -587,16 +584,48 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     bias_correction1 = 1 - std::pow(beta1, step);
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  size_t max_size = 0;
+  // Check tensor list sizes
+  // 4 tensor lists: g, p, m, v
+  // 5 tensor lists: g, p, m, v, p_master
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5,
+             "Expected 4 or 5 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  const auto p_in_type_te = tensor_lists[1][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(p_in_type_te));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    if (num_tensor_lists == 5) {
+      NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
+                 " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+                 ", but expected dtype=", to_string(DType::kFloat32));
+    }
+  }
+
+  // Check if 64-bit indices are required
   bool requires_64bit_indexing = false;
   for (size_t i = 0; i < num_tensor_lists; i++) {
     for (size_t j = 0; j < num_tensors_per_list; j++) {
-      if (tensor_lists[i][j]->numel() > max_size) {
-        max_size = tensor_lists[i][j]->numel();
-        if (max_size >= INT_MAX) {
-          requires_64bit_indexing = true;
-          break;
-        }
+      if (tensor_lists[i][j]->numel() >= INT_MAX) {
+        requires_64bit_indexing = true;
+        break;
       }
     }
     if (requires_64bit_indexing) {
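The bias-correction factors computed in the hunk above are the standard Adam terms, bias_correction1 = 1 - beta1^step and bias_correction2 = 1 - beta2^step, used to rescale the moment estimates early in training. A self-contained sketch of the math with default Adam hyperparameters (illustrative only, not TE code):

```cpp
#include <cmath>
#include <cstdio>

// With Adam defaults (beta1=0.9, beta2=0.999), both correction factors start
// near zero and approach 1 as step grows, which is why the raw moments m and v
// are divided by them: m_hat = m / bc1, v_hat = v / bc2.
int main() {
  const float beta1 = 0.9f, beta2 = 0.999f;
  for (int step : {1, 10, 100, 1000}) {
    const float bias_correction1 = 1.0f - std::pow(beta1, step);
    const float bias_correction2 = 1.0f - std::pow(beta2, step);
    std::printf("step=%4d  bc1=%.6f  bc2=%.6f\n", step, bias_correction1, bias_correction2);
  }
  return 0;
}
```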
@@ -604,16 +633,10 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     }
   }

-  const auto g_in_type_te = tensor_lists[0][0]->dtype();
-  const auto p_in_type_te = tensor_lists[1][0]->dtype();
-
-  // case 4: g, p, m, v
-  // case 5: g, p, m, v, p_master
-  NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");
-
+  // Launch kernel
   if (requires_64bit_indexing) {
     if (num_tensor_lists == 4) {
-      // Assume single type across p,g,m1,m2 now
+      // g, p, m, v
       TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
           p_in_type_te, p_in_type,
           TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
@@ -637,7 +660,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     }
   } else {
     if (num_tensor_lists == 4) {
-      // Assume single type across p,g,m1,m2 now
+      // g, p, m, v
       TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
           p_in_type_te, p_in_type,
           TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
@@ -647,6 +670,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
               stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
               (adamMode_t)mode, weight_decay);));
     } else {
+      // g, p, m, v, p_master
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
           p_in_type_te, p_in_type,
           TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
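The rewritten scan above also drops the `max_size` bookkeeping: only whether any tensor reaches the `INT_MAX` threshold matters for choosing the kernel's index width, since per-element offsets like `i_start + threadIdx.x + ii * blockDim.x` overflow a 32-bit int beyond that point. A minimal sketch of the dispatch idea, where `launch_adam_kernel` is a hypothetical stand-in for the TYPE_SWITCH machinery:

```cpp
#include <climits>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the templated kernel launch; index_t would be the
// integer type used for element offsets inside the functor.
template <typename index_t>
void launch_adam_kernel(const std::vector<int64_t>& /*numels*/) {
  // ... instantiate the functor with index_t and launch ...
}

void dispatch(const std::vector<int64_t>& numels) {
  bool requires_64bit_indexing = false;
  for (int64_t n : numels) {
    if (n >= INT_MAX) {  // same early-exit threshold test as the diff
      requires_64bit_indexing = true;
      break;
    }
  }
  if (requires_64bit_indexing) {
    launch_adam_kernel<int64_t>(numels);  // offsets may exceed 2^31 - 1
  } else {
    launch_adam_kernel<int32_t>(numels);  // cheaper 32-bit arithmetic suffices
  }
}
```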
@@ -667,8 +691,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
                                             const float epsilon, const int step, const int mode,
                                             const int bias_correction, const float weight_decay,
                                             const int device_id, cudaStream_t stream) {
-  const size_t num_tensor_lists = tensor_lists.size();
-
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -676,23 +698,43 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
     bias_correction1 = 1 - std::pow(beta1, step);
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  const auto g_in_type_te = tensor_lists[0][0]->dtype();
-  const auto p_in_type_te = tensor_lists[1][0]->dtype();
-
-  // case 5: g, p, m, v, p_master
-  NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
-  NVTE_CHECK(p_in_type_te == DType::kBFloat16,
-             "Adam with BF16 param remainders requires BF16 params");
+  // Check tensor list sizes
+  // 5 tensor lists: g, p, m, v, p_remainder
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
+  }

-  // g, p, m, v, p_master
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kBFloat16));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j,
+               " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kInt16));
+  }

+  // Launch kernel
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       g_in_type_te, g_in_type,
       multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
                             AdamFunctorMasterParamRemainder<g_in_type>(), device_id, stream,
                             beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
                             (adamMode_t)mode, weight_decay););
-
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
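multi_tensor_adam_param_remainder_cuda pairs each BF16 param with an Int16 "remainder" so that the two together encode an FP32 master weight; this is why the hunks at the top of this file move the 16-bit halves verbatim instead of value-casting them. A sketch of the bit-splitting idea (names and layout are illustrative, not necessarily TE's exact scheme):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// An FP32 bit pattern split into its upper 16 bits (which are exactly a BF16
// encoding of the value) and its lower 16 bits (the "remainder").
struct SplitParam {
  uint16_t bf16_bits;  // upper half: what the BF16 param tensor stores
  uint16_t remainder;  // lower half: what the Int16 remainder tensor stores
};

SplitParam split(float master) {
  uint32_t bits;
  std::memcpy(&bits, &master, sizeof(bits));
  return {static_cast<uint16_t>(bits >> 16), static_cast<uint16_t>(bits & 0xFFFFu)};
}

float merge(SplitParam s) {
  uint32_t bits = (static_cast<uint32_t>(s.bf16_bits) << 16) | s.remainder;
  float master;
  std::memcpy(&master, &bits, sizeof(master));
  return master;
}

int main() {
  float master = 0.123456789f;
  // The round trip is bit-exact, which is why the kernel must copy the 16-bit
  // halves verbatim: any value-converting cast would corrupt the low bits.
  assert(merge(split(master)) == master);
  return 0;
}
```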
@@ -702,9 +744,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
                                 const int step, const int mode, const int bias_correction,
                                 const float weight_decay, const DType fp8_dtype,
                                 const int device_id, cudaStream_t stream) {
-  const size_t num_tensor_lists = tensor_lists.size();
-  const size_t num_tensors_per_list = tensor_lists[0].size();
-
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -712,16 +751,53 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
     bias_correction1 = 1 - std::pow(beta1, step);
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  size_t max_size = 0;
+  // Check tensor list sizes
+  // 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(
+        tensor_lists[1][j]->dtype() == fp8_dtype || tensor_lists[1][j]->dtype() == DType::kByte,
+        "Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+        ", but expected dtype=", to_string(fp8_dtype));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
+               " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[5][j]->dtype() == DType::kFloat32, "Scale tensor ", j,
+               " has dtype=", to_string(tensor_lists[5][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[6][j]->dtype() == DType::kFloat32, "Absmax tensor ", j,
+               " has dtype=", to_string(tensor_lists[6][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[7][j]->dtype() == DType::kFloat32, "Scale-inverse tensor ", j,
+               " has dtype=", to_string(tensor_lists[7][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+  }
+
+  // Check if 64-bit indices are required
   bool requires_64bit_indexing = false;
   for (size_t i = 0; i < num_tensor_lists; i++) {
     for (size_t j = 0; j < num_tensors_per_list; j++) {
-      if (tensor_lists[i][j]->numel() > max_size) {
-        max_size = tensor_lists[i][j]->numel();
-        if (max_size >= INT_MAX) {
-          requires_64bit_indexing = true;
-          break;
-        }
+      if (tensor_lists[i][j]->numel() >= INT_MAX) {
+        requires_64bit_indexing = true;
+        break;
       }
     }
     if (requires_64bit_indexing) {
@@ -729,11 +805,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
     }
   }

-  const auto g_in_type_te = tensor_lists[0][0]->dtype();
-
-  // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
-  NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");
-
+  // Launch kernel
   if (requires_64bit_indexing) {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
         fp8_dtype, FP8_T,
@@ -764,6 +836,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, Tensor
                                        step, const int mode, const int bias_correction,
                                        const float weight_decay, Tensor inv_scale,
                                        const int device_id, cudaStream_t stream) {
+  // Check tensor list sizes
+  // 4 tensor lists: g, p, m, v
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 4, "Expected 4 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+  }
+
+  // Launch kernel
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
@@ -782,6 +882,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
                                               const int bias_correction, const float weight_decay,
                                               Tensor inv_scale, const int device_id,
                                               cudaStream_t stream) {
+  // Check tensor list sizes
+  // 5 tensor lists: g, p, m, v, p_master
+  const size_t num_tensor_lists = tensor_lists.size();
+  NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
+  const size_t num_tensors_per_list = tensor_lists[0].size();
+  for (size_t i = 1; i < num_tensor_lists; i++) {
+    NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
+               " has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
+  }
+
+  // Check tensor dtypes
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  for (size_t j = 0; j < num_tensors_per_list; j++) {
+    NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
+               " has dtype=", to_string(tensor_lists[0][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
+               " has dtype=", to_string(tensor_lists[1][j]->dtype()),
+               ", but expected dtype=", to_string(g_in_type_te));
+    NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[2][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
+               " has dtype=", to_string(tensor_lists[3][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+    NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
+               " has dtype=", to_string(tensor_lists[4][j]->dtype()),
+               ", but expected dtype=", to_string(DType::kFloat32));
+  }
+
+  // Launch kernel
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
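All of the new checks pin the moment tensors to Float32. The sketch below (illustrative only, not TE code) shows why a 16-bit moment buffer would be a bug: truncating Adam's second moment to BF16 precision on every step makes the running average drift away from the FP32 reference:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Simulate what storing a value in BF16 would keep: only the top 16 bits of
// the FP32 bit pattern (ignoring rounding mode, for illustration).
float truncate_to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const float beta2 = 0.999f;
  const float grad = 1e-3f;
  float v_fp32 = 0.0f, v_bf16 = 0.0f;
  for (int step = 0; step < 1000; step++) {
    // Second-moment update: v = beta2 * v + (1 - beta2) * g^2.
    v_fp32 = beta2 * v_fp32 + (1 - beta2) * grad * grad;
    // Same update, but round-tripped through 16-bit storage each step.
    v_bf16 = truncate_to_bf16(beta2 * v_bf16 + (1 - beta2) * grad * grad);
  }
  std::printf("v kept in fp32:        %.9g\n", v_fp32);
  std::printf("v stored as bf16 bits: %.9g\n", v_bf16);
  return 0;
}
```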
diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh
index 77e4369365..4727f3964f 100644
--- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh
+++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh
@@ -52,7 +52,7 @@ class OptionalCUDAGuard {

   ~OptionalCUDAGuard() {
     if (device_changed_) {
-      NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
+      cudaSetDevice(prev_device_);
     }
   }

diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 6c395837fb..858945251f 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -46,6 +46,8 @@ std::string to_string(const DType type) {
       return "Float8E8M0";
     case DType::kFloat4E2M1:
       return "Float4E2M1";
+    case DType::kInt16:
+      return "Int16";
     case DType::kInt32:
       return "Int32";
     case DType::kInt64:
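On the OptionalCUDAGuard change: destructors are implicitly noexcept in C++11 and later, and NVTE_CHECK_CUDA throws on failure, so a failing cudaSetDevice inside the destructor could invoke std::terminate, including during stack unwinding from another exception. The diff therefore restores the previous device without the throwing check. A simplified sketch of the guard pattern (constructor details are assumed here, not the exact TE class):

```cpp
#include <cuda_runtime.h>

// RAII guard that switches to a given CUDA device and restores the previous
// one on scope exit. The destructor must not throw, so the restore call is
// best-effort and any error is deliberately ignored.
class OptionalCUDAGuard {
 public:
  explicit OptionalCUDAGuard(int device) {
    cudaGetDevice(&prev_device_);
    if (device >= 0 && device != prev_device_) {
      device_changed_ = cudaSetDevice(device) == cudaSuccess;
    }
  }

  ~OptionalCUDAGuard() {
    if (device_changed_) {
      cudaSetDevice(prev_device_);  // never throws, even if it fails
    }
  }

 private:
  int prev_device_ = -1;
  bool device_changed_ = false;
};
```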