From 566d9d348adeb88a2e7c159ae3a6a8ace7d91740 Mon Sep 17 00:00:00 2001 From: Manu Seth <22492939+mseth10@users.noreply.github.com> Date: Mon, 27 Jul 2020 10:04:30 -0700 Subject: [PATCH 1/5] [v1.x] add large matrix tests for linalg ops: det, inverse, trsm, trmm (#18744) * add linalg large matrix tests * add batch inputs linalg tests * reducing bsize to 1 to save time * move matrix generator to utils * passing mat size as arg * import util fn * fix sanity * add mx * call backward * merge fn * update grad value * refactor tests * add mx * add shape check Co-authored-by: Ubuntu --- python/mxnet/test_utils.py | 13 ++++ tests/nightly/test_large_array.py | 120 +++++++++++++++++++++++++++++- 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 9a70f6e268e6..f5d2979c3916 100755 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -297,6 +297,19 @@ def create_vector(size, dtype=np.int64): a = mx.nd.arange(0, size, dtype=dtype) return a +# For testing Large Square Matrix with total size > 2^32 elements +def get_identity_mat(size): + A = mx.nd.zeros((size, size)) + for i in range(size): + A[i, i] = 1 + return A + +# For testing Batch of Large Square Matrix with total size > 2^32 elements +def get_identity_mat_batch(size): + A = get_identity_mat(size) + A_np = A.asnumpy() + return mx.nd.array([A_np, A_np]) + def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None, data_init=None, rsp_indices=None, modifier_func=None, shuffle_csr_indices=False, ctx=None): diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index f2128ba70df1..020a70702501 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -25,7 +25,7 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../python/unittest/')) -from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch from mxnet import gluon, nd from common import with_seed, with_post_test_cleanup from nose.tools import with_setup @@ -1207,9 +1207,127 @@ def check_syrk_batch(): assert A.grad[0,0,0] == 4 assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5) + def check_det(): + def run_det(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.det(inp) + return inp.grad, out + + A = get_identity_mat(LARGE_SQ_X) + grad, out = run_det(A) + assert(out.shape == (1,)) + assert(out[0] == 1) + out.backward() + assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0] == 1) + + def check_inverse(): + def run_inverse(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.inverse(inp) + return inp.grad, out + + A = get_identity_mat(LARGE_SQ_X) + grad, out = run_inverse(A) + assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0] == 1) + out.backward() + assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0] == -1) + + def check_trmm(): + def run_trmm(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.trmm(inp, inp) + return inp.grad, out + + A = get_identity_mat(LARGE_SQ_X) + grad, out = run_trmm(A) + assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0] == 1) + 
out.backward() + assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0] == 2) + + def check_trsm(): + def run_trsm(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.trsm(inp, inp) + return inp.grad, out + + A = get_identity_mat(LARGE_SQ_X) + grad, out = run_trsm(A) + assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0] == 1) + out.backward() + assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0] == 0) + + def check_batch_inverse(): + def run_inverse(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.inverse(inp) + return inp.grad, out + + B = get_identity_mat_batch(LARGE_SQ_X) + grad, out = run_inverse(B) + assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0, 0] == 1) + assert(out[1, 0, 0] == 1) + out.backward() + assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0, 0] == -1) + assert(grad[1, 0, 0] == -1) + + def check_batch_trmm(): + def run_trmm(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.trmm(inp, inp) + return inp.grad, out + + B = get_identity_mat_batch(LARGE_SQ_X) + grad, out = run_trmm(B) + assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0, 0] == 1) + assert(out[1, 0, 0] == 1) + out.backward() + assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0, 0] == 2) + assert(grad[1, 0, 0] == 2) + + def check_batch_trsm(): + def run_trsm(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.trsm(inp, inp) + return inp.grad, out + + B = get_identity_mat_batch(LARGE_SQ_X) + grad, out = run_trsm(B) + assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0, 0] == 1) + assert(out[1, 0, 0] == 1) + out.backward() + assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0, 0] == 0) + assert(grad[1, 0, 0] == 0) + check_potrf() check_potri() check_syrk_batch() + check_det() + check_inverse() + check_trmm() + check_trsm() + check_batch_inverse() + check_batch_trmm() + check_batch_trsm() def test_basic(): From d0093458e3be5e76d78750043c4e5a3f01a7d056 Mon Sep 17 00:00:00 2001 From: Chaitanya Prakash Bapat Date: Mon, 27 Jul 2020 20:28:43 -0700 Subject: [PATCH 2/5] [1.x][LT] Add forward, backward test for linalg.gemm2 (#18784) * added forward, backward test for gemm2 * add backward check * correct gradient assert * move test inside linalg_ops * add shape checks --- tests/nightly/test_large_array.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 020a70702501..306c827bab9f 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1207,6 +1207,25 @@ def check_syrk_batch(): assert A.grad[0,0,0] == 4 assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5) + def check_gemm2(): + def run_gemm2(inp1,inp2): + inp1.attach_grad() + inp2.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.gemm2(inp1,inp2) + return inp1.grad, inp2.grad, out + + inp1=mx.nd.ones(shape=(SMALL_Y, LARGE_X)) + inp1[0][0]=0.1 + inp2=mx.nd.ones(shape=(LARGE_X, SMALL_Y)) + inp1_grad, inp2_grad, out= run_gemm2(inp1,inp2) + assert out.asnumpy()[0][0] == LARGE_X + assert out.shape == (SMALL_Y, SMALL_Y) + out.backward() + assert inp1_grad.shape == (SMALL_Y, LARGE_X) + assert inp2_grad.shape == (LARGE_X, SMALL_Y) + assert_almost_equal(inp2_grad.asnumpy()[0][0],49.1) + def check_det(): def run_det(inp): inp.attach_grad() @@ -1321,6 +1340,7 @@ def run_trsm(inp): 
check_potrf() check_potri() check_syrk_batch() + check_gemm2() check_det() check_inverse() check_trmm() From 7bef9cb23b72c3b5b93c10d87e09db19f442d12e Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Tue, 28 Jul 2020 16:58:07 -0700 Subject: [PATCH 3/5] Back port optimization to broadcast_axis to MXNet1.x (#18773) * Improving performance of broadcast_axis on GPU (#18168) * adding separate int32_t kernel for GPU in broadcast_axis/to/like operators * using structure instead of temp workspace to pass stride and shape * replacing hardcoded int32_t with generic index_t * combining CPU and GPU kernels to leverage cached stride calculation and fast access shape data in both Co-authored-by: Rohit Kumar Srivastava * Improve performance of broadcast_axis on CPU (#17882) * adding comments explaining code optimizations * fixing broadcast_axis kernel to int32 * fixing slice_axis kernel to int32 * combining CPU and GPU implementation method signatures and cleaned up code * adding new broadcast_axis to np_matmul Co-authored-by: Rohit Kumar Srivastava Co-authored-by: Rohit Kumar Srivastava --- src/operator/numpy/np_matmul_op-inl.h | 40 ++++- src/operator/tensor/broadcast_reduce_op.h | 208 ++++++++++++++++++++-- 2 files changed, 224 insertions(+), 24 deletions(-) diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h index 89560f64d8c0..8f1b4f9f3a30 100644 --- a/src/operator/numpy/np_matmul_op-inl.h +++ b/src/operator/numpy/np_matmul_op-inl.h @@ -138,6 +138,8 @@ inline void MatmulImpl(const OpContext& ctx, mshadow::Tensor workspace; mshadow::Tensor ans, mlhs, mrhs; mshadow::Stream *s = ctx.get_stream(); + bool isCPU = std::is_same::value; + // Is true if either a or b requires broadcast or not if (MatmulNeedBroadcast(a_shape, b_shape)) { // e.g. 
a.shape = (2, 3, 1, 4, 2) // b.shape = (5, 2, 4) @@ -157,12 +159,38 @@ inline void MatmulImpl(const OpContext& ctx, DType* bc_b_ptr = bc_a_ptr + bc_size_a; MSHADOW_TYPE_SWITCH_WITH_BOOL(input_a.type_flag_, IType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(input_b.type_flag_, OType, { - Kernel, xpu>::Launch( - s, bc_size_a, input_a.dptr(), bc_a_ptr, - k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim); - Kernel, xpu>::Launch( - s, bc_size_b, input_b.dptr(), bc_b_ptr, - k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim); + struct ShapeAndStride aux_data_a, aux_data_b; + PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim); + PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim); + if (isCPU) { + if (!aux_data_a.shape_changed) { + Kernel, xpu>::Launch( + s, bc_size_a, input_a.dptr(), bc_a_ptr, OpReqType::kWriteTo); + Kernel, xpu>::Launch( + s, input_b.Size(), input_b.dptr(), bc_b_ptr, + aux_data_b, OpReqType::kWriteTo, ndim); + } else if (!aux_data_b.shape_changed) { + Kernel, xpu>::Launch( + s, bc_size_b, input_b.dptr(), bc_b_ptr, OpReqType::kWriteTo); + Kernel, xpu>::Launch( + s, input_a.Size(), input_a.dptr(), bc_a_ptr, + aux_data_a, OpReqType::kWriteTo, ndim); + } else { + Kernel, xpu>::Launch( + s, input_a.Size(), input_a.dptr(), bc_a_ptr, + aux_data_a, OpReqType::kWriteTo, ndim); + Kernel, xpu>::Launch( + s, input_b.Size(), input_b.dptr(), bc_b_ptr, + aux_data_b, OpReqType::kWriteTo, ndim); + } + } else { + Kernel, xpu>::Launch( + s, bc_size_a, input_a.dptr(), bc_a_ptr, + aux_data_a, OpReqType::kWriteTo, ndim); + Kernel, xpu>::Launch( + s, bc_size_b, input_b.dptr(), bc_b_ptr, + aux_data_b, OpReqType::kWriteTo, ndim); + } }); }); ans = mshadow::Tensor(output.dptr(), diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 5eb0c41aa36c..82b4f7d1f43a 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -25,6 +25,7 @@ #ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_ #define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_ +#include #include #include #include @@ -1037,34 +1038,182 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs, ReduceAxesBackwardUseInOutImpl(ctx, small, inputs, req, outputs); } +namespace { // unnamed namespace to keep scope of the struct within the file +struct ShapeAndStride { + index_t in_stride[MXNET_SPECIAL_MAX_NDIM]; + index_t out_stride[MXNET_SPECIAL_MAX_NDIM]; + index_t input_shape[MXNET_SPECIAL_MAX_NDIM]; + index_t output_shape[MXNET_SPECIAL_MAX_NDIM]; + // axes: stores which axes in input is to broadcasted + index_t axes[MXNET_SPECIAL_MAX_NDIM]; + int num_broadcast_axes = -1; + bool shape_changed = false; +}; +} // unnamed namespace + +/*! + * \brief Calculates Stride of input and output tensor dimesnions + And saves mshadow::Shape data in an integer array for + faster access. + * \param *aux_data to hold stride and shape data. 
+ * \param in_shape input shape + * \param out_shape output shape + * \param ndim no of dimensions in output + */ +inline void PrepareAUXData(ShapeAndStride *aux_data, + mshadow::Shape in_shape, + mshadow::Shape out_shape, + int ndim) { + int iter = ndim - 1, i = 0; + aux_data->out_stride[iter] = 1; + aux_data->in_stride[iter] = 1; + aux_data->input_shape[iter] = in_shape[iter]; + aux_data->output_shape[iter] = out_shape[iter]; + if (in_shape[iter] != out_shape[iter]) { + aux_data->axes[i++] = iter; + aux_data->shape_changed = true; + } + iter--; + for (; iter >= 0; --iter) { + aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1]; + aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1]; + aux_data->input_shape[iter] = in_shape[iter]; + aux_data->output_shape[iter] = out_shape[iter]; + if (in_shape[iter] != out_shape[iter]) { + aux_data->axes[i++] = iter; + aux_data->shape_changed = true; + } + } + aux_data->num_broadcast_axes = i; + assert(aux_data->num_broadcast_axes > -1 && aux_data->num_broadcast_axes < 4); +} + template -struct broadcast_kernel { +struct broadcast_kernel_gpu { template MSHADOW_XINLINE static void Map(index_t i, IType *input, OType *output, - mshadow::Shape in_shape, - mshadow::Shape out_shape, + const ShapeAndStride& aux_data, const OpReqType req, - const uint32_t ndim) { - size_t in_stride = 1; - size_t out_stride = 1; + const int ndim) { index_t idx = i; index_t in_idx = i; +#pragma unroll 4 for (int iter = ndim - 1; iter >= 0; --iter) { - size_t dim_idx = idx % out_shape[iter]; - in_idx -= dim_idx * out_stride; - if (in_shape[iter] != 1) { - in_idx += dim_idx * in_stride; + index_t out_dim_shape = aux_data.output_shape[iter]; + index_t out_dim_stride = aux_data.out_stride[iter]; + // x % y = x - (x / y) * y + // speeds up modulo(%) operation in GPU + index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape; + if (aux_data.input_shape[iter] != 1) { + in_idx += dim_idx * (aux_data.in_stride[iter] - out_dim_stride); + } else { + in_idx -= dim_idx * out_dim_stride; } - idx /= out_shape[iter]; - in_stride *= in_shape[iter]; - out_stride *= out_shape[iter]; + idx /= out_dim_shape; } KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx])); } }; +/** + * Changed the thread workload mapping from 1 + * thread/output element to 1 thread/input to be broadcasted + * This approach leverages vectorization when fastest varying + * index(stride=1) of the tensor is to be broadcasted. + * In other cases it simply performs better by better load balancing. + */ +template +struct broadcast_kernel_cpu { + template + MSHADOW_XINLINE static void Map(index_t i, + IType *input, + OType *output, + const ShapeAndStride& aux_data, + const OpReqType req, + const int ndim) { + index_t idx = i; + index_t init_off = 0; + for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) { + size_t dim_idx = idx % aux_data.input_shape[iter]; + init_off += dim_idx * aux_data.out_stride[iter]; + idx /= aux_data.input_shape[iter]; + } + index_t stride_0, stride_1, stride_2; + // Each case is based on the number of axis to be broadcasted + // (1, 2 or 3) after merging axes. 
+ switch (aux_data.num_broadcast_axes) { + // when input shape is one of the following forms + // (x_1,1) or (x_1,1,x_2) or (1,x_1) + // x_1, x_2 are size of the dimensions that are not to be broadcasted + // in case of (x_1,1) the system leverages vectorization but in other 2 + // the performance is improved due avoidance of duplicate stride calculations + // for each output location input[i] needs to be written to. + case 1 : + stride_0 = aux_data.out_stride[aux_data.axes[0]]; + for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + l * stride_0], + req, OP::Map(input[i])); + } + break; + // when input shape is one of the follwing forms + // (x_1,1,x_2,1) or (1,x_1,1,x_2) or (x_1,1,x_2,1,x_3) + // x_1, x_2, x_3 are size of the dimensions that are not to be broadcasted + // in the inner most loop can be vectorized by compiler in outer loops + // the performance is improved due avoidance of duplicate stride calculations + // for each output location input[i] needs to be written to. + case 2: + stride_1 = aux_data.out_stride[aux_data.axes[1]]; + stride_0 = aux_data.out_stride[aux_data.axes[0]]; + for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) { + for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + k * stride_1 + l * stride_0], + req, OP::Map(input[i])); + } + } + break; + // when input shape is of the form (1,x_1,1,x_2,1) + // x_1, x_2 are size of the dimensions that are not to be broadcasted + // here the last axis which is [4] is the one where compiler can vectorize + // the code the outer 2 loops improve preformance by avoiding + // duplicate stride calculations + // for each output location input[i] needs to be written to. + case 3: + stride_2 = aux_data.out_stride[aux_data.axes[2]]; + stride_1 = aux_data.out_stride[aux_data.axes[1]]; + stride_0 = aux_data.out_stride[aux_data.axes[0]]; + for (index_t j = 0; j < aux_data.output_shape[aux_data.axes[2]]; j++) { + for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) { + for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) { + KERNEL_ASSIGN(output[init_off + j * stride_2 + k * stride_1 + l * stride_0], + req, OP::Map(input[i])); + } + } + } + break; + } + } +}; + +template +struct direct_copy { + template + MSHADOW_XINLINE static void Map(index_t i, + IType *input, + OType *output, + const OpReqType req) { + KERNEL_ASSIGN(output[i], req, OP::Map(input[i])); + } +}; + +/** + * When CPU context is used the no. of kernel launches are equal to + * the no. of input elements, this helps leverage vectorization when possible + * When GPU context is used no. of kernel launches are equal to + * the no. of output elements, this ensures coalesced memory writes to output + * and improves coalesced memory reads. + */ template inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -1076,8 +1225,14 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, using namespace mshadow::expr; using namespace mxnet_op; mxnet::TShape src_shape, dst_shape; + // combines 2 or more consecutive broadcast/non-broadcast axes together + // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (3*4,1*1,5,1,6*7) (1,3) (5*10, 9) + // -> (12,1,5,1,42) (1,3) (50, 9) + // and this is the new input for broadcast_kernel whose total + // num of dimensions cannot be greater than 5(throws an error otherwise). 
BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape); Stream *s = ctx.get_stream(); + bool isCPU = std::is_same::value; MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { mshadow::Shape in_shape; @@ -1091,21 +1246,38 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, out_shape[i] = 1; } } - if (dst_shape.ndim() == 2) { + struct ShapeAndStride aux_data; + PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim()); + if (!aux_data.shape_changed) { + // If no broadcast is required (i.e. input_shape == output_shape) + // then simply copy input to outout. + Kernel, xpu>::Launch( + s, outputs[0].Size(), inputs[0].dptr(), outputs[0].dptr(), req[0]); + } else if (dst_shape.ndim() == 2) { Tensor out = outputs[0].get_with_shape(dst_shape.get<2>(), s); Tensor data = inputs[0].get_with_shape(src_shape.get<2>(), s); - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2); + if (isCPU) { + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); + } else { + Kernel, xpu>::Launch( + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); + } } else { const int ndim = MXNET_SPECIAL_MAX_NDIM; Tensor out = outputs[0].get_with_shape(dst_shape.get(), s); Tensor data = inputs[0].get_with_shape(src_shape.get(), s); - Kernel, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim); + if (isCPU) { + Kernel, xpu>::Launch( + s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); + } else { + Kernel, xpu>::Launch( + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); + } } }); }); From 85eb528c1f53cc3b88ea4596d02d6bfa251b9953 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Tue, 28 Jul 2020 23:08:53 -0700 Subject: [PATCH 4/5] Add syrk test shape check (#18812) * add shape check * add name to contributor.md Co-authored-by: Ubuntu --- CONTRIBUTORS.md | 1 + tests/nightly/test_large_array.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index bd7f966aaa5f..be04d82ffe92 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -252,6 +252,7 @@ List of Contributors * [Oliver Kowalke](https://github.com/olk) * [Connor Goggins](https://github.com/connorgoggins) * [Joe Evans](https://github.com/josephevans) +* [Zhaoqi Zhu](https://github.com/zha0q1) Label Bot --------- diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 306c827bab9f..8865eae2b81e 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1201,9 +1201,11 @@ def check_syrk_batch(): A.attach_grad() with mx.autograd.record(): out = nd.linalg.syrk(A, alpha=2, transpose=False) + assert out.shape == (2, LARGE_SQ_X, LARGE_SQ_X) assert out[0,0,0] == 2 assert_almost_equal(out[1,0,0], nd.array([0.02]), rtol=1e-3, atol=1e-5) out.backward() + assert A.grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X) assert A.grad[0,0,0] == 4 assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5) From ca6bcf3480f8663b87e55b04d9769c44ec53e727 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 29 Jul 2020 15:16:10 -0700 Subject: [PATCH 5/5] adding error message when attempting to use Large tensor with linalg_syevd (#18807) Co-authored-by: Rohit Kumar Srivastava --- src/operator/tensor/la_op.h | 2 ++ tests/nightly/test_large_array.py | 13 ++++++++++++- 2 files changed, 14 
insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index e15390ecde5a..cd097781243b 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -475,6 +475,8 @@ inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs, const mxnet::TShape& in_a = (*in_attrs)[0]; const mxnet::TShape& out_u = (*out_attrs)[0]; const mxnet::TShape& out_l = (*out_attrs)[1]; + CHECK_LE(in_a.Size(), INT_MAX) + << "Large tensors are not supported by Linear Algebra operator syevd"; if ( in_a.ndim() >= 2 ) { // Forward shape inference. const int ndim(in_a.ndim()); diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 8865eae2b81e..d55d4e55cc6e 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -27,7 +27,8 @@ from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch from mxnet import gluon, nd -from common import with_seed, with_post_test_cleanup +from common import with_seed, with_post_test_cleanup, assertRaises +from mxnet.base import MXNetError from nose.tools import with_setup import unittest @@ -1352,6 +1353,16 @@ def run_trsm(inp): check_batch_trsm() +def test_linalg_errors(): + def check_syevd_error(): + A = get_identity_mat(LARGE_SQ_X) + for i in range(LARGE_SQ_X): + A[i,i] = 1 + assertRaises(MXNetError, mx.nd.linalg.syevd, A) + + check_syevd_error() + + def test_basic(): def check_elementwise(): a = nd.ones(shape=(LARGE_X, SMALL_Y))
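
Note (not part of the patches above): the constants asserted in check_det, check_inverse, check_trmm and check_trsm in patch 1 (forward value 1 at [0, 0], gradients 1, -1, 2 and 0 respectively) follow from the inputs being identity matrices, so they do not depend on LARGE_SQ_X. Assuming an MXNet build with autograd and the linalg operators available, the trmm case can be sanity-checked cheaply on a small identity; the size n = 4 below is an illustrative stand-in for LARGE_SQ_X and is not used anywhere in the tests themselves.

    import mxnet as mx

    n = 4                                  # small stand-in for LARGE_SQ_X
    A = mx.nd.eye(n)                       # identity input, as built by get_identity_mat
    A.attach_grad()
    with mx.autograd.record():
        out = mx.nd.linalg.trmm(A, A)      # I * I = I
    out.backward()                         # implicit head gradient of ones
    assert out[0, 0] == 1                  # same forward value the nightly test asserts
    assert A.grad[0, 0] == 2               # A contributes through both trmm arguments

The factor 2 arises because the same tensor is passed as both arguments, so the gradients from the triangular and the dense operand accumulate on A.grad.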
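
Note (illustrative, not part of patch 3): broadcast_kernel_cpu launches one kernel invocation per *input* element. It decodes the flat input index against the input shape, accumulates the matching output offset (init_off) from the output strides stored in ShapeAndStride, and then replicates the value along the merged broadcast axes (the case 1/2/3 loops). The NumPy sketch below models that mapping for the single-broadcast-axis case only; the function name, toy shapes and dtype are assumptions made for the example, and the real kernel additionally handles 2 or 3 merged broadcast axes and KERNEL_ASSIGN/OpReqType semantics.

    import numpy as np

    def cpu_style_broadcast(inp, out_shape):
        # Mirror of the "case 1" path: one pass per input element, writes to
        # every output location along the single broadcast axis.
        in_shape = inp.shape
        ndim = len(out_shape)
        out_stride = [1] * ndim
        for ax in range(ndim - 2, -1, -1):
            out_stride[ax] = out_stride[ax + 1] * out_shape[ax + 1]
        bcast_axes = [ax for ax in range(ndim) if in_shape[ax] != out_shape[ax]]
        assert len(bcast_axes) == 1        # this sketch only covers one broadcast axis
        out = np.empty(out_shape, dtype=inp.dtype).ravel()
        flat_in = inp.ravel()
        for i in range(flat_in.size):      # one "thread" per input element
            idx, init_off = i, 0
            for ax in range(ndim - 1, -1, -1):   # decode i against the *input* shape
                init_off += (idx % in_shape[ax]) * out_stride[ax]
                idx //= in_shape[ax]
            ax0 = bcast_axes[0]
            for l in range(out_shape[ax0]):      # replicate along the broadcast axis
                out[init_off + l * out_stride[ax0]] = flat_in[i]
        return out.reshape(out_shape)

    a = np.arange(6, dtype=np.float32).reshape(3, 1, 2)
    assert np.array_equal(cpu_style_broadcast(a, (3, 4, 2)),
                          np.broadcast_to(a, (3, 4, 2)))

Iterating over input elements is what lets the compiler vectorize the innermost replication loop when the fastest-varying (stride 1) axis is the one being broadcast, which is the motivation stated in the kernel's comment block.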