
Commit 5517ecf

fixed workspace allocation for TP overlap test with pure GEMM
Signed-off-by: Alp Dener <adener@nvidia.com>
Parent: 3d7ff1c

1 file changed, 3 insertions(+), 3 deletions(-)


tests/pytorch/distributed/run_gemm_with_overlap.py

@@ -598,7 +598,7 @@ def _fp8_gemm():
         tex.FP8FwdTensors.GEMM1_INPUT,
         fp8_dtype,
         torch.uint8 if opts.fp8_output else torch.bfloat16,
-        te.module.base.get_workspace(),
+        te.module.base.get_workspace().repeat(3),
         bias=None,
         use_bias=False,
         gelu=False,
@@ -639,7 +639,7 @@ def _fp8_gemm2(gemm1_out):
         tex.FP8FwdTensors.GEMM2_INPUT,
         fp8_dtype,
         torch.uint8 if opts.fp8_output else torch.bfloat16,
-        te.module.base.get_workspace(),
+        te.module.base.get_workspace().repeat(3),
         bias=None,
         use_bias=False,
         gelu=False,
@@ -662,7 +662,7 @@ def _gemm():
         kernel_t,
         gemm_inp,
         torch.bfloat16,
-        te.module.base.get_workspace(),
+        te.module.base.get_workspace().repeat(3),
         bias=None,
         use_bias=False,
         gelu=False,
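
For context, here is a minimal, self-contained sketch of the idea behind the change. The `get_workspace()` helper below is a hypothetical stand-in for `te.module.base.get_workspace()`, and the buffer size is illustrative, not Transformer Engine's actual default: the test runs three GEMMs, so a workspace buffer sized for a single GEMM is tripled with `torch.Tensor.repeat(3)` so each call has its own region rather than sharing one undersized allocation.

```python
import torch

# Hypothetical stand-in for te.module.base.get_workspace(): a cached 1-D
# uint8 buffer that serves as cuBLAS workspace for a single GEMM call.
# The 32 MiB size here is illustrative only.
_WORKSPACE = None

def get_workspace(nbytes: int = 32 * 1024 * 1024) -> torch.Tensor:
    """Return a cached 1-D uint8 workspace tensor (size is an assumption)."""
    global _WORKSPACE
    if _WORKSPACE is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _WORKSPACE = torch.empty(nbytes, dtype=torch.uint8, device=device)
    return _WORKSPACE

# .repeat(3) on a 1-D tensor allocates a new tensor three times as long,
# which is what the commit passes to each GEMM in the overlap test.
workspace = get_workspace().repeat(3)
assert workspace.numel() == 3 * get_workspace().numel()
```

Note that `repeat()` copies into a fresh allocation rather than viewing the cached buffer, so the tripled workspace is a distinct tensor sized for all three GEMM invocations.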
