Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add fix to bce_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jul 31, 2018
1 parent 98a41af commit e06a1e7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/python/unittest/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_ce_loss():
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05


@with_seed(1234)
@with_seed()
def test_bce_loss():
N = 20
data = mx.random.uniform(-1, 1, shape=(N, 20))
Expand All @@ -105,7 +105,7 @@ def test_bce_loss():
prob_npy = 1.0 / (1.0 + np.exp(-data.asnumpy()))
label_npy = label.asnumpy()
npy_bce_loss = - label_npy * np.log(prob_npy) - (1 - label_npy) * np.log(1 - prob_npy)
assert_almost_equal(mx_bce_loss, npy_bce_loss)
assert_almost_equal(mx_bce_loss, npy_bce_loss, rtol=1e-4, atol=1e-5)

@with_seed()
def test_bce_equal_ce2():
Expand Down

0 comments on commit e06a1e7

Please sign in to comment.