
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization #1921

Open
wants to merge 4 commits into `main`

Conversation

@xiaoxi-wangfj (Contributor) commented on Jul 3, 2025

1. Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`, removing the explicit padding/unpadding in the MoE experts module, improving performance and reducing peak GPU memory usage.
2. Added tests for the fused permute/pad and unpermute/unpad operations.

Description

This PR optimizes FP8 MoE permute and pad operations by:

  1. Fusing `moe_permute_with_probs` + `Fp8Padding` into `moe_permute_and_pad_with_probs`
  2. Fusing `moe_unpermute` + `Fp8Unpadding` into `moe_unpermute` with a new `pad_offsets` argument
  3. Thereby removing the explicit padding/unpadding steps in the MoE experts module (a usage sketch follows below)
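
To make the new call pattern concrete, here is a minimal sketch of the fused path. The argument order and the five return values are inferred from the Megatron-LM integration shown later in this PR; treat the exact signatures, and the placeholder tensor setup, as assumptions rather than a definitive reference.

```python
import torch
from transformer_engine.pytorch.permutation import (
    moe_permute_and_pad_with_probs,  # fused op added by this PR
    moe_unpermute,                   # existing op, extended here with pad_offsets
)

num_tokens, hidden, num_experts, topk, align_size = 1024, 4096, 8, 2, 16

tokens = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool, device="cuda")
routing_map.scatter_(1, torch.rand(num_tokens, num_experts, device="cuda").topk(topk, dim=1).indices, True)
probs = torch.rand(num_tokens, num_experts, device="cuda") * routing_map
tokens_per_expert = routing_map.sum(dim=0)

# Fused permute + pad: one call produces the permuted activations, already padded to a
# multiple of align_size per expert, plus pad_offsets describing where the padding rows sit.
permuted, permuted_probs, row_id_map, pad_offsets, padded_tokens_per_expert = (
    moe_permute_and_pad_with_probs(tokens, probs, routing_map, tokens_per_expert, align_size)
)

expert_output = permuted  # stand-in for the grouped expert GEMMs

# Fused unpermute + unpad: passing pad_offsets lets the kernel drop the padding rows while
# restoring the original token order, so no separate Fp8Unpadding pass is needed.
restored = moe_unpermute(expert_output, row_id_map, pad_offsets=pad_offsets)
```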

Results:

  • 1.1x~1.6x speedup for the fused permute-and-pad operation
  • 1.7x~3x speedup for the fused unpermute-and-unpad operation (measured with tests/pytorch/test_permutation.py)
  • Verified in end-to-end FP8 model training with the Megatron framework: +0.4% MFU uplift and ~1 GB peak GPU memory reduction in a typical ~600B-parameter setup.

Performance data

Tests covering a wide range of model training configurations were performed, comparing the fused operations ("Fused:") with the original version ("Orig:"). Running times (in milliseconds) are summarized in the table below, along with the speedup, measured as the ratio of the original running time to the fused running time. All tests were carried out with the tests/pytorch/test_permutation.py benchmark script.

[Fused-perm-pad: benchmark table image]

Usage in Megatron-LM

  1. Megatron-LM/megatron/core/transformer/moe/moe_utils.py: added support for the fused operations

```python

# Added fused function import
from megatron.core.extensions.transformer_engine import (
    ...,
    fused_permute_and_pad_with_probs,  # [!code ++]
)

def permute(
    ...,
    tokens_per_expert: Optional[torch.Tensor] = None,  # [!code ++]
    align_size: int = -1  # [!code ++]
):
    ...
    if fused and probs is not None:
        if not HAVE_TE or fused_permute_with_probs is None:
            raise ValueError(
                "fused_permute_with_probs is not available. Please install TE >= 2.1.0."
            )
        if tokens_per_expert is not None and align_size > 0:  # [!code ++]
            # Use fused permute+pad operation [!code ++]
            return fused_permute_and_pad_with_probs(tokens, probs, routing_map, tokens_per_expert, align_size)  # [!code ++]
        else:
            # Fallback to original implementation
            ...


def unpermute(
    ...,
    pad_offsets: Optional[torch.Tensor] = None  # [!code ++]
):
    return fused_unpermute(
        ...,
        pad_offsets=pad_offsets  # [!code ++]
    )

```
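
The `fused_permute_and_pad_with_probs` symbol imported above would live in the Megatron-LM TE extension layer. The PR does not show that wrapper's body, so the following is only a plausible sketch, with the import guard and pass-through as assumptions.

```python
# Hypothetical wrapper in megatron/core/extensions/transformer_engine.py (not shown in
# this PR); the guard mirrors how other optional TE ops are typically exposed there.
try:
    from transformer_engine.pytorch.permutation import moe_permute_and_pad_with_probs
except ImportError:
    moe_permute_and_pad_with_probs = None


def fused_permute_and_pad_with_probs(tokens, probs, routing_map, tokens_per_expert, align_size):
    """Permute tokens by expert and pad each expert's chunk to a multiple of align_size in
    one fused TE op. Returns (permuted, permuted_probs, row_id_map, pad_offsets,
    padded_tokens_per_expert), matching the unpacking in the dispatcher snippet below."""
    if moe_permute_and_pad_with_probs is None:
        raise ValueError(
            "moe_permute_and_pad_with_probs is not available. "
            "Please install a Transformer Engine build that includes this PR."
        )
    return moe_permute_and_pad_with_probs(
        tokens, probs, routing_map, tokens_per_expert, align_size
    )
```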

  2. Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py: scheduler integration

```python

class _DeepepManager(_DispatchManager):
    def __init__(...):
        self.pad_offsets = None  # [!code ++] Store padding offsets
    
    def get_permuted_hidden_states_by_experts(...):
        ...
        if self.config.moe_permute_padding_for_fp8:  # [!code ++]
            # Use fused path [!code ++]
            (                                                          # [!code ++]
                hidden_states,                                         # [!code ++]
                permuted_probs,                                        # [!code ++]
                self.reversed_mapping_for_combine,                     # [!code ++]
                self.pad_offsets,                                      # [!code ++]
                self.tokens_per_expert                                 # [!code ++]
            ) = permute(                                               # [!code ++]
                hidden_states,                                         # [!code ++]
                self.dispatched_routing_map,                           # [!code ++]
                probs=self.dispatched_probs,                           # [!code ++]
                fused=self.permute_fusion,                             # [!code ++]
                tokens_per_expert=self.tokens_per_expert,              # [!code ++]
                align_size=get_fp8_align_size(self.config.fp8_recipe), # [!code ++]
            )                                                          # [!code ++]
        else:
            # Original path
            ...
    
    def get_restored_hidden_states_by_experts(...):
        hidden_states = unpermute(
            ...,
            pad_offsets=self.pad_offsets if self.config.moe_permute_padding_for_fp8 else None, # [!code ++]
        )
        ...

```
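
The dispatcher snippet calls `get_fp8_align_size`, which this PR does not define. A hypothetical placeholder is sketched below, assuming a 16-row alignment for per-tensor recipes and the block length (e.g. 128 for the 1x128 blockwise case discussed in the review thread) otherwise; the `block_len` attribute is invented for illustration.

```python
def get_fp8_align_size(fp8_recipe) -> int:
    """Hypothetical helper: the row-count multiple each expert's chunk is padded to so the
    FP8 grouped GEMMs see aligned shapes. The exact rule and attribute are assumptions."""
    block_len = getattr(fp8_recipe, "block_len", None)  # invented attribute for this sketch
    if block_len:
        # Blockwise scaling (e.g. 1x128) needs the token dimension padded to the block size.
        return int(block_len)
    return 16  # assumed common alignment for per-tensor / delayed-scaling FP8 GEMMs
```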

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a `moe_permute_and_pad_with_probs` API for fused permute-and-pad, and extended the `moe_unpermute` API with a `pad_offsets` argument for fused unpermute-and-unpad, in transformer_engine/pytorch/permutation.py (a simplified equivalence check is sketched after this list)
  • Added tests in tests/pytorch/test_permutation.py
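
As referenced above, a simplified version of the kind of equivalence check the new tests perform is sketched here. The reference padding helper and the assumed return arities of the two permute ops are illustrative, not a copy of the PR's actual test.

```python
import torch
from transformer_engine.pytorch.permutation import (
    moe_permute_with_probs,
    moe_permute_and_pad_with_probs,
)


def reference_pad(permuted, tokens_per_expert, align_size):
    """Zero-pad each expert's contiguous chunk of rows up to a multiple of align_size."""
    chunks, start = [], 0
    for n in tokens_per_expert.tolist():
        chunk = permuted[start:start + n]
        pad_rows = (-n) % align_size
        if pad_rows:
            chunk = torch.cat([chunk, chunk.new_zeros(pad_rows, chunk.shape[1])])
        chunks.append(chunk)
        start += n
    return torch.cat(chunks)


def check_fused_matches_unfused(tokens, probs, routing_map, tokens_per_expert, align_size=16):
    # Unfused reference: permute first, then pad each expert chunk explicitly.
    ref, _, _ = moe_permute_with_probs(tokens, probs, routing_map)
    ref_padded = reference_pad(ref, tokens_per_expert, align_size)
    # Fused path under test.
    fused, _, _, _, _ = moe_permute_and_pad_with_probs(
        tokens, probs, routing_map, tokens_per_expert, align_size
    )
    torch.testing.assert_close(fused, ref_padded)
```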

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

1. Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`,
   removing the explicit padding/unpadding in the GroupedMLP layer, improving performance and reducing peak GPU memory usage.
2. Added tests for fused permute/pad and unpermute/unpad.

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
@xiaoxi-wangfj xiaoxi-wangfj marked this pull request as draft July 4, 2025 04:49
@xiaoxi-wangfj xiaoxi-wangfj marked this pull request as ready for review July 4, 2025 05:03
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
@yaox12 (Member) commented on Jul 11, 2025

Thanks for your contribution.
We also noticed there are redundant reads and writes in the current "permute and then pad" routine. We plan to tackle this by padding the routing map before permutation; refer to this commit in Megatron-LM.
For permute fusion, we try to avoid tiny CUDA kernels as much as possible. I went through your PR and found that these lines may introduce some of them.
We prefer padding the routing map over fusing permutation and padding because tokens_per_expert is produced by the preprocess step and we don't want it to be changed later; otherwise it may cause some confusion.
So we won't merge this PR. Thanks again for your contribution.
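
For readers unfamiliar with the alternative the maintainers describe, the idea of padding the routing map before permutation can be sketched roughly as below. This is only a conceptual illustration, not the referenced Megatron-LM commit.

```python
import torch


def pad_routing_map(routing_map: torch.Tensor, align_size: int) -> torch.Tensor:
    """Conceptual sketch: mark extra (token, expert) slots as routed, with zero probability,
    so every expert's token count is already a multiple of align_size before permutation."""
    padded = routing_map.clone()
    tokens_per_expert = padded.sum(dim=0)
    for e in range(padded.shape[1]):
        deficit = int((-tokens_per_expert[e]) % align_size)
        if deficit == 0:
            continue
        free_slots = (~padded[:, e]).nonzero(as_tuple=True)[0]
        if free_slots.numel() < deficit:
            # Not enough spare slots to pad this expert; Megatron-LM warns and falls back
            # (the "not_enough_tokens_to_pad" case discussed in the reply below).
            continue
        padded[free_slots[:deficit], e] = True
    return padded
```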

@xiaoxi-wangfj (Contributor, Author) commented

> Thanks for your contribution. We also noticed there are redundant reads and writes in the current "permute and then pad" routine. We plan to tackle this by padding the routing map before permutation; refer to this commit in Megatron-LM. For permute fusion, we try to avoid tiny CUDA kernels as much as possible. I went through your PR and found that these lines may introduce some of them. We prefer padding the routing map over fusing permutation and padding because tokens_per_expert is produced by the preprocess step and we don't want it to be changed later; otherwise it may cause some confusion. So we won't merge this PR. Thanks again for your contribution.

Thank you for your response.

  1. The moe_router_padding_for_fp8 and fused_permute_pad_for_fp8 configurations are compatible. Within the moe_router_padding_for_fp8 logic, if not_enough_tokens_to_pad is triggered, execution falls back to the fused_permute_pad_for_fp8 computational path; otherwise, pad_offsets is set to None.

  2. We previously enabled the Megatron-LM commit configuration you mentioned, but found that padding the routing map can cause loss instability. During pre-training, especially with larger FP8 align sizes such as 128 (due to our 1*128 blockwise setting), it frequently triggered the not_enough_tokens_to_pad warning, forcing a fallback to explicit padding within GroupedMLP. This halved iteration performance when it occurred and sometimes led to loss deterioration, so we disabled routing-map padding.
     This observation motivated our development of the fused permute-and-pad path. Because the fused implementation updates tokens_per_expert itself, it requires no configuration changes for Fp8Padding/Fp8Unpadding. Do you have any suggestions for refining this fused approach? We hope to continue improving this optimization (a small worked example of the alignment math follows below).
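
To illustrate the point about large align sizes, here is a small worked example of the per-expert padding requirement, assuming the usual round-up-to-the-next-multiple rule:

```python
def padding_needed(tokens_per_expert, align_size):
    """Rows of padding each expert needs to reach the next multiple of align_size."""
    return [(-n) % align_size for n in tokens_per_expert]

# With align_size = 16, a mildly unbalanced routing needs little padding:
print(padding_needed([130, 120, 127, 135], 16))   # [14, 8, 1, 9]

# With align_size = 128 (the 1*128 blockwise case above), the same token counts need far
# more padding, so routing-map padding runs out of spare slots much more often:
print(padding_needed([130, 120, 127, 135], 128))  # [126, 8, 1, 121]
```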
