@@ -45,7 +45,6 @@
 from transformer_engine.jax.activation import activation
 from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
-from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
 
 from transformer_engine_jax import is_non_nt_fp8_gemm_supported
 

@@ -1133,15 +1132,15 @@ def ref_func(x, w, gamma, beta):
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("use_bias", [True, False])
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad(
         self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
     ):
         """
         Test layernorm_mlp VJP Rule
         """
-
+        use_jax_dot_for_gemm(enabled=with_jax_gemm)
 
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
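For orientation, here is a minimal sketch of what a GEMM-backend toggle like `use_jax_dot_for_gemm` could look like: a test helper that flips a switch so the dense and `layernorm_mlp` paths fall back to native `jax.lax.dot_general` instead of Transformer Engine's fused-GEMM custom call. The environment-variable name and mechanism below are assumptions for illustration, not the actual Transformer Engine helper.

```python
import os

# Assumed flag name for illustration only; the real Transformer Engine
# toggle may use a different variable or mechanism.
_JAX_GEMM_FLAG = "NVTE_JAX_USE_JAX_GEMM"


def use_jax_dot_for_gemm(enabled: bool = True) -> None:
    """Route GEMMs through native JAX dot_general when enabled; otherwise
    leave the default Transformer Engine custom-call GEMM path in place."""
    if enabled:
        os.environ[_JAX_GEMM_FLAG] = "1"
    else:
        os.environ.pop(_JAX_GEMM_FLAG, None)
```

Called at the top of the test body, as in the hunk above, a helper of this shape lets the `with_jax_gemm` parametrization run the same `layernorm_mlp` gradient checks against both GEMM backends.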