Commit da0709a (1 parent: cf1774c)

minor unit test cleanup

Signed-off-by: Alp Dener <adener@nvidia.com>

File tree: 2 files changed (+2, -4 lines)

tests/jax/test_custom_call_compute.py

Lines changed: 2 additions & 3 deletions
@@ -45,7 +45,6 @@
 from transformer_engine.jax.activation import activation
 from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
-from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
 
 from transformer_engine_jax import is_non_nt_fp8_gemm_supported
 
@@ -1133,15 +1132,15 @@ def ref_func(x, w, gamma, beta):
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("use_bias", [True, False])
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad(
         self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
     ):
         """
         Test layernorm_mlp VJP Rule
         """
-
+        use_jax_dot_for_gemm(enabled=with_jax_gemm)
 
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
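
For context on the decorator swap in the hunk above: `pytest_parametrize_wrapper` is taken here to be a thin test-suite helper around `pytest.mark.parametrize` that standardizes parameter ids across tests. The sketch below only illustrates that assumption; it is not the helper actually defined in the TransformerEngine test suite.

# Minimal sketch of a parametrize wrapper (assumption: the real helper in the
# TE test suite may differ; this only illustrates the decorator's shape).
import pytest

def pytest_parametrize_wrapper(name, values):
    """Forward to pytest.mark.parametrize with human-readable test ids."""
    ids = [f"{name}={v}" for v in values]
    return pytest.mark.parametrize(name, values, ids=ids)

# Usage mirrors the changed decorator in the hunk above:
# @pytest_parametrize_wrapper("use_bias", [True, False])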

transformer_engine/jax/cpp_extensions/quantization.py

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@
     Quantizer,
     GroupedQuantizer,
     QuantizeLayout,
-    DelayedScaleQuantizer,
     ScalingMode,
     compute_scale_from_amax,
 )

0 commit comments
