Skip to content

Commit 39f0375

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b1b51c3 commit 39f0375

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

transformer_engine/jax/gemm.py

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -225,9 +225,9 @@ def _gemm_bwd_rule(
225225
if dgrad_overlap_config["method"] == "bulk":
226226
# Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the
227227
# bulk RS overlap without an extra memcpy.
228-
assert wgrad_overlap_config is not None, (
229-
f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
230-
)
228+
assert (
229+
wgrad_overlap_config is not None
230+
), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
231231
dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False)
232232

233233
# Copy transposed input into the DGRAD overlap buffer for bulk AG.
@@ -275,7 +275,6 @@ def _gemm_bwd_rule(
275275
# Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor
276276
dgrad = dgrad_extra_out
277277

278-
279278
# WGRAD w/o Overlap:
280279
# AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P)
281280
#
@@ -663,13 +662,11 @@ def _fp8_gemm_bwd_rule(
663662
dgrad_scale = None
664663
if dgrad_overlap_config is not None:
665664
if dgrad_overlap_config["method"] == "bulk":
666-
assert wgrad_overlap_config is not None, (
667-
f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
668-
)
665+
assert (
666+
wgrad_overlap_config is not None
667+
), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!"
669668
# Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap
670-
dgrad_pre_rs = jax.dlpack.from_dlpack(
671-
tex.get_overlap_buffer(wgrad_overlap_name, False)
672-
)
669+
dgrad_pre_rs = jax.dlpack.from_dlpack(tex.get_overlap_buffer(wgrad_overlap_name, False))
673670
# Copy input into overlap buffer for all-gather
674671
copy_into_overlap_buffer(casted_x_t, dgrad_overlap_name, True)
675672

@@ -710,15 +707,12 @@ def _fp8_gemm_bwd_rule(
710707
if wgrad_overlap_config is not None:
711708
if wgrad_overlap_config["method"] == "bulk":
712709
# Get all-gathered input from DGRAD bulk overlap
713-
casted_x_t = jax.dlpack.from_dlpack(
714-
tex.get_overlap_buffer(dgrad_overlap_name, False)
715-
)
710+
casted_x_t = jax.dlpack.from_dlpack(tex.get_overlap_buffer(dgrad_overlap_name, False))
716711

717712
elif tex.overlap_buffer_is_fp8(wgrad_overlap_name):
718713
# Set FP8 scale inverse for non-bulk AG overlap
719714
tex.set_overlap_buffer_scale_inverse(
720-
wgrad_overlap_name,
721-
jax.dlpack.to_dlpack(x_scale_inv)
715+
wgrad_overlap_name, jax.dlpack.to_dlpack(x_scale_inv)
722716
)
723717

724718
# WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K)

0 commit comments

Comments
 (0)