Skip to content

Handle dtypes more carefully in multi-tensor Adam #1888

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 182 additions & 51 deletions transformer_engine/common/multi_tensor/adam.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ struct AdamFunctorMasterParamRemainder {
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);

local_p[ii] = static_cast<int16_t>(p[i]);
local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
local_p[ii] = p[i];
local_p_rem[ii] = p_remainder[i];
} else {
r_g[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
Expand Down Expand Up @@ -280,8 +280,8 @@ struct AdamFunctorMasterParamRemainder {
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
p[i] = static_cast<int16_t>(local_p[ii]);
p_remainder[i] = local_p_rem[ii];
p[i] = local_p[ii];

m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
Expand Down Expand Up @@ -466,8 +466,8 @@ struct AdamCapturableFunctor {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = static_cast<T>(r_p[ii]);
m[i] = static_cast<T>(r_m[ii]);
v[i] = static_cast<T>(r_v[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
Expand Down Expand Up @@ -577,43 +577,66 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();

// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}

size_t max_size = 0;
// Check tensor list sizes
// 4 tensor lists: g, p, m, v
// 5 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5,
"Expected 4 or 5 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}

// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(p_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
if (num_tensor_lists == 5) {
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
}

// Check if 64-bit indices are required
bool requires_64bit_indexing = false;
for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) {
max_size = tensor_lists[i][j]->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
if (tensor_lists[i][j]->numel() >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
if (requires_64bit_indexing) {
break;
}
}

const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();

// case 4: g, p, m, v
// case 5: g, p, m, v, p_master
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");

// Launch kernel
if (requires_64bit_indexing) {
if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now
// g, p, m, v
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
Expand All @@ -637,7 +660,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
}
} else {
if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now
// g, p, m, v
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
Expand All @@ -647,6 +670,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
Expand All @@ -667,32 +691,50 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();

// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}

const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();

// case 5: g, p, m, v, p_master
NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
NVTE_CHECK(p_in_type_te == DType::kBFloat16,
"Adam with BF16 param remainders requires BF16 params");
// Check tensor list sizes
// 5 tensor lists: g, p, m, v, p_remainder
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}

// g, p, m, v, p_master
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(DType::kBFloat16));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kInt16));
}

// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay););

NVTE_CHECK_CUDA(cudaGetLastError());
}

Expand All @@ -702,38 +744,68 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
const int step, const int mode, const int bias_correction,
const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();

// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}

size_t max_size = 0;
// Check tensor list sizes
// 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}

// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(
tensor_lists[1][j]->dtype() == fp8_dtype || tensor_lists[1][j]->dtype() == DType::kByte,
"Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(fp8_dtype));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[5][j]->dtype() == DType::kFloat32, "Scale tensor ", j,
" has dtype=", to_string(tensor_lists[5][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[6][j]->dtype() == DType::kFloat32, "Absmax tensor ", j,
" has dtype=", to_string(tensor_lists[6][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[7][j]->dtype() == DType::kFloat32, "Scale-inverse tensor ", j,
" has dtype=", to_string(tensor_lists[7][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}

// Check if 64-bit indices are required
bool requires_64bit_indexing = false;
for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) {
max_size = tensor_lists[i][j]->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
if (tensor_lists[i][j]->numel() >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
if (requires_64bit_indexing) {
break;
}
}

const auto g_in_type_te = tensor_lists[0][0]->dtype();

// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");

// Launch kernel
if (requires_64bit_indexing) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
Expand Down Expand Up @@ -764,6 +836,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 4, "Expected 4 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}

// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}

// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
Expand All @@ -782,6 +882,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
const int bias_correction, const float weight_decay,
Tensor inv_scale, const int device_id,
cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 5, "Expected 4 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}

// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}

// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class OptionalCUDAGuard {

~OptionalCUDAGuard() {
if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
cudaSetDevice(prev_device_);
}
}

Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ std::string to_string(const DType type) {
return "Float8E8M0";
case DType::kFloat4E2M1:
return "Float4E2M1";
case DType::kInt16:
return "Int16";
case DType::kInt32:
return "Int32";
case DType::kInt64:
Expand Down