File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
tests/pytorch/distributed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -598,7 +598,7 @@ def _fp8_gemm():
598
598
tex .FP8FwdTensors .GEMM1_INPUT ,
599
599
fp8_dtype ,
600
600
torch .uint8 if opts .fp8_output else torch .bfloat16 ,
601
- te .module .base .get_workspace (),
601
+ te .module .base .get_workspace (). repeat ( 3 ) ,
602
602
bias = None ,
603
603
use_bias = False ,
604
604
gelu = False ,
@@ -639,7 +639,7 @@ def _fp8_gemm2(gemm1_out):
639
639
tex .FP8FwdTensors .GEMM2_INPUT ,
640
640
fp8_dtype ,
641
641
torch .uint8 if opts .fp8_output else torch .bfloat16 ,
642
- te .module .base .get_workspace (),
642
+ te .module .base .get_workspace (). repeat ( 3 ) ,
643
643
bias = None ,
644
644
use_bias = False ,
645
645
gelu = False ,
@@ -662,7 +662,7 @@ def _gemm():
662
662
kernel_t ,
663
663
gemm_inp ,
664
664
torch .bfloat16 ,
665
- te .module .base .get_workspace (),
665
+ te .module .base .get_workspace (). repeat ( 3 ) ,
666
666
bias = None ,
667
667
use_bias = False ,
668
668
gelu = False ,
You can’t perform that action at this time.
0 commit comments