Skip to content

[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization #1921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
342 changes: 342 additions & 0 deletions tests/pytorch/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from transformer_engine.pytorch import (
moe_permute as te_permute,
moe_permute_with_probs as te_permute_with_probs,
moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs,
moe_unpermute as te_unpermute,
moe_sort_chunks_by_index as te_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
Expand All @@ -25,6 +26,7 @@
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding
import copy

seed = 1234
Expand Down Expand Up @@ -646,6 +648,302 @@ def _test_permutation_mask_map(
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")


def _test_permutation_and_padding_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
align_size=16,
BENCHMARK=False,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")

if num_out_tokens == None:
num_out_tokens = num_tokens * topK

print(
"mask map:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)

# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
else:
pytest.skip("Invalid dtype.")

_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = (
torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
)

probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
probs = probs.to(dtype)
probs.requires_grad_(True)

tokens_per_expert = routing_map.sum(dim=0).cpu()
target_tokens_per_expert = (
torch.ceil(tokens_per_expert / align_size) * align_size
).long()
num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()

permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
permute_pad_bwd_input = torch.rand(
(num_permute_pad_out_tokens, hidden_size), dtype=dtype
).cuda()
unpermute_unpad_bwd_input = torch.rand(
(num_tokens, hidden_size), dtype=dtype
).cuda()
permute_pad_fwd_input.requires_grad_(True)

restore_shape = permute_pad_fwd_input.shape
###################################################################################################################################
#
# moe_permute_with_probs and Fp8Padding, moe_unpermute and Fp8Unpadding
#
###################################################################################################################################
# permute + padding
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
permute_pad_fwd_input,
probs,
routing_map,
num_out_tokens=num_out_tokens,
)
tokens_per_expert_list = tokens_per_expert.tolist()
fp8_padding = Fp8Padding(num_expert, align_size)
permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
permuted_paded_probs, _ = fp8_padding(
permuted_probs.unsqueeze(-1), tokens_per_expert_list
)

permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)

# unpadding + unpermute

unpermute_unpad_fwd_input = permuted_paded_output.detach()
unpermute_unpad_fwd_input.requires_grad_(True)

fp8_unpadding = Fp8Unpadding(num_expert, align_size)
unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
unpermuted_unpaded_output = te_unpermute(
unpaded_output, row_id_map, restore_shape=restore_shape
)

unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

###################################################################################################################################
#
# fusion moe_permute_with_probs and Fp8Padding, fusion fusion moe_unpermute and Fp8Unpadding
#
###################################################################################################################################
# fusion permute_and_pad
fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach()
fusion_permute_and_pad_fwd_input.requires_grad_(True)
probs = probs.detach()
probs.requires_grad_(True)

(
fusion_permuted_padded_output,
fusion_permuted_padded_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
) = te_permute_and_pad_with_probs(
fusion_permute_and_pad_fwd_input,
probs,
routing_map,
tokens_per_expert,
align_size,
)
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)

fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
fusion_permuted_padded_output.backward(
fusion_permute_pad_bwd_input, retain_graph=True
)

# fusion unpad and unpermute
fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach()
fusion_unpermute_unpad_fwd_input.requires_grad_(True)

fusion_unpermuted_unpaded_output = te_unpermute(
fusion_unpermute_unpad_fwd_input,
row_id_map,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)

fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
fusion_unpermuted_unpaded_output.backward(
fusion_unpermute_bwd_input, retain_graph=True
)

###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)

permuted_paded_output_ = permuted_paded_output.float()
fusion_permuted_padded_output_ = fusion_permuted_padded_output.float()
permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float()
fusion_permute_and_pad_fwd_input_grad = (
fusion_permute_and_pad_fwd_input.grad.float()
)

unpermuted_unpaded_output_ = unpermuted_unpaded_output.float()
fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float()
unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float()
fusion_unpermute_unpad_fwd_input_grad = (
fusion_unpermute_unpad_fwd_input.grad.float()
)

torch.testing.assert_close(
permuted_paded_output_,
fusion_permuted_padded_output_,
msg=f"Mismatch in te_permute_and_pad fwd",
**tols,
)
torch.testing.assert_close(
permute_pad_fwd_input_grad,
fusion_permute_and_pad_fwd_input_grad,
msg=f"Mismatch in te_permute_and_pad bwd",
**tols,
)
torch.testing.assert_close(
unpermuted_unpaded_output_,
fusion_unpermuted_unpaded_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
unpermute_unpad_fwd_input_grad,
fusion_unpermute_unpad_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
torch.testing.assert_close(
permuted_paded_probs.float(),
fusion_permuted_padded_probs.float(),
msg=f"Mismatch in te_permute_and_pad bwd",
**tols,
)

###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
if BENCHMARK:

def permute_and_pad():
permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
permute_pad_fwd_input,
probs,
routing_map,
num_out_tokens=num_out_tokens,
)
fp8_padding(permuted_output, tokens_per_expert_list)
fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)

def fusion_permute_and_pad():
(
fusion_permuted_padded_output,
fusion_permuted_padded_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
) = te_permute_and_pad_with_probs(
fusion_permute_and_pad_fwd_input,
probs,
routing_map,
tokens_per_expert,
align_size,
)
fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)

t1 = perf_test_cuda_kernel(lambda: permute_and_pad())

t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad())

print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
permuted_paded_output,
permute_pad_bwd_input,
forward_input=[permute_pad_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
fusion_permuted_padded_output,
fusion_permute_pad_bwd_input,
forward_input=[fusion_permute_and_pad_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

def unpad_unpermute():
unpaded_output = fp8_unpadding(
unpermute_unpad_fwd_input, tokens_per_expert_list
)
unpermuted_unpaded_output = te_unpermute(
unpaded_output, row_id_map, restore_shape=restore_shape
)

unpermuted_unpaded_output.backward(
unpermute_unpad_bwd_input, retain_graph=True
)

t1 = perf_test_cuda_kernel(lambda: unpad_unpermute())
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(
fusion_unpermute_unpad_fwd_input,
row_id_map,
restore_shape=restore_shape,
pad_offsets=pad_offsets,
)
)
print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
unpermuted_unpaded_output,
unpermute_unpad_bwd_input,
forward_input=([unpermute_unpad_fwd_input, probs]),
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
fusion_unpermuted_unpaded_output,
fusion_unpermute_bwd_input,
forward_input=([fusion_unpermute_unpad_fwd_input, probs]),
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")


def _test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
Expand Down Expand Up @@ -1119,6 +1417,40 @@ def test_permutation_mask_map(
)


@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
"num_tokens, num_expert, hidden_size, topK",
[
(4096, 64, 1280, 7),
(4096, 64, 2048, 6),
(4096, 160, 5120, 6),
(4096, 256, 7168, 8),
(4096, 384, 8192, 8),
(4096, 512, 9216, 8),
],
)
def test_permutation_and_padding_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
BENCHMARK = True

_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=BENCHMARK,
)


@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_empty_input(te_dtype):
with_probs = True
Expand Down Expand Up @@ -1352,6 +1684,16 @@ def test_permutation_single_case():
BENCHMARK=Benchmark,
)

_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=Benchmark,
)

_test_moe_chunk_sort(
te_dtype=te_dtype,
num_tokens=num_tokens,
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def torch_version() -> tuple[int, ...]:
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_permute_with_probs,
moe_permute_and_pad_with_probs,
moe_unpermute,
moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs,
Expand Down
Loading