diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index d55d4e55cc6e..1f4443aac175 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1211,23 +1211,49 @@ def check_syrk_batch(): assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5) def check_gemm2(): - def run_gemm2(inp1,inp2): + def run_gemm2(inp1, inp2): inp1.attach_grad() inp2.attach_grad() with mx.autograd.record(): - out = mx.nd.linalg.gemm2(inp1,inp2) + out = mx.nd.linalg.gemm2(inp1, inp2) return inp1.grad, inp2.grad, out - inp1=mx.nd.ones(shape=(SMALL_Y, LARGE_X)) - inp1[0][0]=0.1 - inp2=mx.nd.ones(shape=(LARGE_X, SMALL_Y)) - inp1_grad, inp2_grad, out= run_gemm2(inp1,inp2) + inp1 = mx.nd.ones(shape=(SMALL_Y, LARGE_X)) + perturbation = 0.2 + inp1[0][0] = perturbation + inp2 = mx.nd.ones(shape=(LARGE_X, SMALL_Y)) + inp1_grad, inp2_grad, out = run_gemm2(inp1, inp2) assert out.asnumpy()[0][0] == LARGE_X assert out.shape == (SMALL_Y, SMALL_Y) out.backward() assert inp1_grad.shape == (SMALL_Y, LARGE_X) assert inp2_grad.shape == (LARGE_X, SMALL_Y) - assert_almost_equal(inp2_grad.asnumpy()[0][0],49.1) + assert_almost_equal(inp1_grad.asnumpy()[0][0], SMALL_Y) + assert_almost_equal(inp2_grad.asnumpy()[0][0], SMALL_Y - (1 - perturbation)) + + def check_gemm(): + def run_gemm(inp1,inp2, inp3): + inp1.attach_grad() + inp2.attach_grad() + inp3.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.gemm(inp1, inp2, inp3, transpose_b=True) + return inp1.grad, inp2.grad, inp3.grad, out + + inp1 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, MEDIUM_X)) + perturbation = 0.2 + inp1[0][0][0] = perturbation + inp2 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, MEDIUM_X)) + inp3 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, SMALL_Y)) + inp1_grad, inp2_grad, inp3_grad, out= run_gemm(inp1, inp2, inp3) + assert_almost_equal(out.asnumpy()[0][0][0], MEDIUM_X + perturbation) + assert out.shape == inp3.shape + out.backward() + assert inp1_grad.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X) + assert inp2_grad.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X) + assert inp3_grad.shape == (MEDIUM_X, SMALL_Y, SMALL_Y) + assert_almost_equal(inp1_grad.asnumpy()[0][0][0], SMALL_Y) + assert_almost_equal(inp2_grad.asnumpy()[0][0][0], SMALL_Y - (1 - perturbation)) def check_det(): def run_det(inp): @@ -1340,6 +1366,7 @@ def run_trsm(inp): assert(grad[0, 0, 0] == 0) assert(grad[1, 0, 0] == 0) + check_gemm() check_potrf() check_potri() check_syrk_batch()