diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py index 8d5b86341a88..0145e04f800f 100644 --- a/tests/python/unittest/test_loss.py +++ b/tests/python/unittest/test_loss.py @@ -82,7 +82,7 @@ def test_ce_loss(): assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05 -@with_seed(1234) +@with_seed() def test_bce_loss(): N = 20 data = mx.random.uniform(-1, 1, shape=(N, 20)) @@ -105,7 +105,7 @@ def test_bce_loss(): prob_npy = 1.0 / (1.0 + np.exp(-data.asnumpy())) label_npy = label.asnumpy() npy_bce_loss = - label_npy * np.log(prob_npy) - (1 - label_npy) * np.log(1 - prob_npy) - assert_almost_equal(mx_bce_loss, npy_bce_loss) + assert_almost_equal(mx_bce_loss, npy_bce_loss, rtol=1e-4, atol=1e-5) @with_seed() def test_bce_equal_ce2():