@@ -225,9 +225,9 @@ def _gemm_bwd_rule(
225
225
if dgrad_overlap_config ["method" ] == "bulk" :
226
226
# Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the
227
227
# bulk RS overlap without an extra memcpy.
228
- assert wgrad_overlap_config is not None , (
229
- f"Missing comm+GEMM overlap config for { wgrad_overlap_name } !"
230
- )
228
+ assert (
229
+ wgrad_overlap_config is not None
230
+ ), f"Missing comm+GEMM overlap config for { wgrad_overlap_name } !"
231
231
dgrad_pre_rs = tex .get_overlap_buffer (wgrad_overlap_name , False )
232
232
233
233
# Copy transposed input into the DGRAD overlap buffer for bulk AG.
@@ -275,7 +275,6 @@ def _gemm_bwd_rule(
275
275
# Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor
276
276
dgrad = dgrad_extra_out
277
277
278
-
279
278
# WGRAD w/o Overlap:
280
279
# AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P)
281
280
#
@@ -663,13 +662,11 @@ def _fp8_gemm_bwd_rule(
663
662
dgrad_scale = None
664
663
if dgrad_overlap_config is not None :
665
664
if dgrad_overlap_config ["method" ] == "bulk" :
666
- assert wgrad_overlap_config is not None , (
667
- f"Missing comm+GEMM overlap config for { wgrad_overlap_name } !"
668
- )
665
+ assert (
666
+ wgrad_overlap_config is not None
667
+ ), f"Missing comm+GEMM overlap config for { wgrad_overlap_name } !"
669
668
# Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap
670
- dgrad_pre_rs = jax .dlpack .from_dlpack (
671
- tex .get_overlap_buffer (wgrad_overlap_name , False )
672
- )
669
+ dgrad_pre_rs = jax .dlpack .from_dlpack (tex .get_overlap_buffer (wgrad_overlap_name , False ))
673
670
# Copy input into overlap buffer for all-gather
674
671
copy_into_overlap_buffer (casted_x_t , dgrad_overlap_name , True )
675
672
@@ -710,15 +707,12 @@ def _fp8_gemm_bwd_rule(
710
707
if wgrad_overlap_config is not None :
711
708
if wgrad_overlap_config ["method" ] == "bulk" :
712
709
# Get all-gathered input from DGRAD bulk overlap
713
- casted_x_t = jax .dlpack .from_dlpack (
714
- tex .get_overlap_buffer (dgrad_overlap_name , False )
715
- )
710
+ casted_x_t = jax .dlpack .from_dlpack (tex .get_overlap_buffer (dgrad_overlap_name , False ))
716
711
717
712
elif tex .overlap_buffer_is_fp8 (wgrad_overlap_name ):
718
713
# Set FP8 scale inverse for non-bulk AG overlap
719
714
tex .set_overlap_buffer_scale_inverse (
720
- wgrad_overlap_name ,
721
- jax .dlpack .to_dlpack (x_scale_inv )
715
+ wgrad_overlap_name , jax .dlpack .to_dlpack (x_scale_inv )
722
716
)
723
717
724
718
# WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K)
0 commit comments