Commit b1b51c3

updated FWD/BWD wrappers for non-FP8 and FP8 gemm
Signed-off-by: Alp Dener <adener@nvidia.com>
1 parent 13a8cd4 commit b1b51c3

2 files changed, +108 -117 lines changed

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 1 addition & 0 deletions
@@ -1074,6 +1074,7 @@ def fp8_gemm_impl(
     bias: Optional[ArrayLike] = None,
     gelu_input: Optional[ArrayLike] = None,
     out: Optional[ArrayLike] = None,
+    extra_out: Optional[ArrayLike] = None,
     out_amax: Optional[ArrayLike] = None,
     out_scale: Optional[ArrayLike] = None,
     out_dtype: jnp.dtype = jnp.bfloat16,
