diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 46b1e41e66b4..49b84a2b9d68 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -2210,7 +2210,7 @@ def hybrid_forward(self, F, x): for i in range(len(chn_list)): net = Net(chn_list[i]) net.initialize(init=init.Constant(1)) - x = mx.nd.zeros((1, 3, 8, 160, 160), ctx=mx.cpu()) + x = mx.nd.zeros((1, 3, 8, 160, 160)) net(x).asnumpy()