diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index 5713d7d14e2a..ea4c9d5e3d0a 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -22,16 +22,24 @@
 from . import fallback_linalg
 
 __all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve',
-           'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank']
+           'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank', 'cross', 'diagonal', 'outer',
+           'tensordot', 'trace', 'matrix_transpose']
 __all__ += fallback_linalg.__all__
 
 
 def matrix_rank(M, tol=None, hermitian=False):
-    r"""Return matrix rank of array using SVD method
+    r"""
+    Return matrix rank of array using SVD method
 
     Rank of the array is the number of singular values of the array that are
     greater than `tol`.
 
+    Notes
+    -----
+    `matrix_rank` corresponds to `linalg.matrix_rank` in the Array API standard, see
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-rank-x-rtol-none
+    (the standard names the tolerance argument `rtol`; this function keeps `tol`).
+
     Parameters
     ----------
     M : {(M,), (..., M, N)} ndarray
@@ -44,7 +52,7 @@ def matrix_rank(M, tol=None, hermitian=False):
     hermitian : bool, optional
         If True, `M` is assumed to be Hermitian (symmetric if real-valued),
         enabling a more efficient method for finding singular values.
-        Defaults to False.
+        Default: False.
 
     Returns
     -------
@@ -54,19 +62,326 @@ def matrix_rank(M, tol=None, hermitian=False):
     Examples
     --------
     >>> from mxnet import np
-    >>> np.matrix_rank(np.eye(4)) # Full rank matrix
+    >>> np.linalg.matrix_rank(np.eye(4)) # Full rank matrix
     4
     >>> I=np.eye(4); I[-1,-1] = 0. # rank deficient matrix
-    >>> np.matrix_rank(I)
+    >>> np.linalg.matrix_rank(I)
     3
-    >>> np.matrix_rank(np.ones((4,))) # 1 dimension - rank 1 unless all 0
+    >>> np.linalg.matrix_rank(np.ones((4,))) # 1 dimension - rank 1 unless all 0
     1
-    >>> np.matrix_rank(np.zeros((4,)))
+    >>> np.linalg.matrix_rank(np.zeros((4,)))
     0
     """
     return _mx_nd_np.linalg.matrix_rank(M, tol, hermitian)
 
 
+def matrix_transpose(a):
+    r"""
+    Transposes a matrix (or a stack of matrices) `a`.
+
+    Notes
+    -----
+    `matrix_transpose` swaps only the last two axes, unlike `transpose`. It is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-transpose-x
+    instead of an official NumPy operator.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array having shape (..., M, N) and whose innermost two dimensions form MxN matrices.
+
+    Returns
+    -------
+    out : ndarray
+        An array containing the transpose for each matrix and having shape (..., N, M).
+        The returned array must have the same data type as `a`.
+
+    Examples
+    --------
+    >>> x = np.arange(4).reshape((2,2))
+    >>> x
+    array([[0., 1.],
+           [2., 3.]])
+    >>> np.linalg.matrix_transpose(x)
+    array([[0., 2.],
+           [1., 3.]])
+    >>> x = np.ones((1, 2, 3))
+    >>> np.linalg.matrix_transpose(x).shape
+    (1, 3, 2)
+    """
+    ndim = a.ndim
+    if ndim < 2:
+        # Fewer than two axes: nothing to swap, defer to the plain `transpose` call.
+        return _mx_nd_np.transpose(a, axes=None)
+    # Swap only the innermost two axes; the leading (stack) axes keep their position.
+    return _mx_nd_np.transpose(a, axes=tuple(range(ndim - 2)) + (ndim - 1, ndim - 2))
+
+
+def trace(a, offset=0):
+    r"""
+    Returns the sum along the specified diagonals of a matrix (or a stack of matrices) `a`.
+
+    Notes
+    -----
+    `linalg.trace` applies `trace` to the last two axes of `a`. It is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-trace-x-offset-0
+    instead of an official NumPy operator.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array having shape (..., M, N) and whose innermost two dimensions form MxN matrices.
+        Should have a numeric data type.
+    offset : int
+        Offset specifying the off-diagonal relative to the main diagonal.
+
+        offset = 0 : the main diagonal.
+        offset > 0 : off-diagonal above the main diagonal.
+        offset < 0 : off-diagonal below the main diagonal.
+
+        Default: 0.
+
+    Returns
+    -------
+    out : ndarray
+        An array containing the traces and whose shape is determined by removing the last two dimensions and storing
+        the traces in the last array dimension. For example, if `a` has rank `k` and shape `(I, J, K, ..., L, M, N)`,
+        then an output array has rank `k-2` and shape `(I, J, K, ..., L)`
+        where: `out[i, j, k, ..., l] = trace(a[i, j, k, ..., l, :, :])`
+        The returned array must have the same data type as `a`.
+
+    Examples
+    --------
+    >>> x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+    >>> np.linalg.trace(x)
+    array(3.)
+    >>> x = np.arange(8).reshape((2, 2, 2))
+    >>> np.linalg.trace(x)
+    array([ 3., 11.])
+    >>> x = np.arange(24).reshape((2, 2, 2, 3))
+    >>> np.linalg.trace(x).shape
+    (2, 2)
+    >>> np.linalg.trace(x)
+    array([[ 4., 16.],
+           [28., 40.]])
+    """
+    # axis1, axis2: use the last two axes of `a` (the matrix axes), not the first two.
+    return _mx_nd_np.trace(a, offset=offset, axis1=a.ndim - 2, axis2=a.ndim - 1, out=None)
+
+
+def tensordot(a, b, axes=2):
+    r"""
+    Returns a tensor contraction of `a` and `b` over specific axes.
+
+    Notes
+    -----
+    `linalg.tensordot` is an alias for `tensordot`. It is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-tensordot-x1-x2-axes-2
+    instead of an official NumPy operator.
+
+    Parameters
+    ----------
+    a : ndarray
+        First input array. Should have a numeric data type.
+    b : ndarray
+        Second input array. Must be compatible with `a` (see Broadcasting). Should have a numeric data type.
+    axes : int, tuple
+        Number of axes to contract or explicit sequences of axes for `a` and `b`, respectively.
+        If axes is an int equal to `N` , then contraction must be performed over the last `N` axes of `a`
+        and the first `N` axes of `b` in order.
+        The size of each corresponding axis (dimension) must match. Must be nonnegative.
+
+        If N equals 0 , the result is the tensor (outer) product.
+        If N equals 1 , the result is the tensor dot product.
+        If N equals 2 , the result is the tensor double contraction (default).
+
+        Default: 2.
+
+    Returns
+    -------
+    out : ndarray
+        An array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the
+        first array `a`, followed by the non-contracted axes (dimensions) of the second array `b`.
+
+    Examples
+    --------
+    >>> x = np.arange(60.).reshape(3,4,5)
+    >>> y = np.arange(24.).reshape(4,3,2)
+    >>> z = np.linalg.tensordot(x, y, axes=([1,0],[0,1]))
+    >>> z.shape
+    (5, 2)
+    >>> z
+    array([[ 4400.,  4730.],
+           [ 4532.,  4874.],
+           [ 4664.,  5018.],
+           [ 4796.,  5162.],
+           [ 4928.,  5306.]])
+    """
+    return _mx_nd_np.tensordot(a, b, axes)
+
+
+def diagonal(a, offset=0):
+    r"""
+    Returns the specified diagonals of a matrix (or a stack of matrices) `a`.
+
+    Notes
+    -----
+    `linalg.diagonal` applies `diagonal` to the last two axes of `a`. It is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-diagonal-x-offset-0
+    instead of an official NumPy operator.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array having shape (..., M, N) and whose innermost two dimensions form MxN matrices.
+    offset : int
+        Selects which diagonal of each matrix is returned.
+        Offset specifying the off-diagonal relative to the main diagonal.
+
+        offset = 0 : the main diagonal.
+        offset > 0 : off-diagonal above the main diagonal.
+        offset < 0 : off-diagonal below the main diagonal.
+
+        Default: 0.
+
+    Returns
+    -------
+    out : ndarray
+        An array containing the diagonals and whose shape is determined by removing the last two dimensions and
+        appending a dimension equal to the size of the resulting diagonals.
+        The returned array must have the same data type as `a`.
+
+    Examples
+    --------
+    >>> x = np.arange(9).reshape((3,3))
+    >>> x
+    array([[0., 1., 2.],
+           [3., 4., 5.],
+           [6., 7., 8.]])
+    >>> np.linalg.diagonal(x)
+    array([0., 4., 8.])
+    >>> np.linalg.diagonal(x, offset=1)
+    array([1., 5.])
+    >>> np.linalg.diagonal(x, offset=-1)
+    array([3., 7.])
+    """
+    if a.ndim < 2:
+        # Not a matrix: keep the `diag` call so existing 1-D callers see no change.
+        return _mx_nd_np.diag(a, k=offset)
+    # Take the diagonal over the innermost two axes so stacks of matrices work.
+    return _mx_nd_np.diagonal(a, offset=offset, axis1=a.ndim - 2, axis2=a.ndim - 1)
+
+
+def cross(a, b, axis=-1):
+    r"""
+    Returns the cross product of 3-element vectors.
+
+    If `a` and `b` are multi-dimensional arrays (i.e., both have a rank greater than 1),
+    then the cross-product of each pair of corresponding 3-element vectors is independently computed.
+
+    Notes
+    -----
+    `linalg.cross` calls `cross` with the same `axis` for `a`, `b` and the result. It is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cross-x1-x2-axis-1
+    instead of an official NumPy operator.
+
+    Parameters
+    ----------
+    a : ndarray
+        First input array. Should have a numeric data type.
+    b : ndarray
+        Second input array. Must have the same shape as a. Should have a numeric data type.
+    axis : int
+        If defined, the axis of `a` and `b` that defines the vector(s) and cross product(s).
+
+        Default: -1.
+
+    Returns
+    -------
+    out : (...) ndarray
+        An array containing the cross products.
+
+    Examples
+    --------
+    Vector cross-product.
+
+    >>> x = np.array([1., 2., 3.])
+    >>> y = np.array([4., 5., 6.])
+    >>> np.linalg.cross(x, y)
+    array([-3.,  6., -3.])
+
+    One vector with dimension 2.
+
+    >>> x = np.array([1., 2.])
+    >>> y = np.array([4., 5., 6.])
+    >>> np.linalg.cross(x, y)
+    array([12., -6., -3.])
+
+    Equivalently:
+
+    >>> x = np.array([1., 2., 0.])
+    >>> y = np.array([4., 5., 6.])
+    >>> np.linalg.cross(x, y)
+    array([12., -6., -3.])
+
+    Both vectors with dimension 2.
+
+    >>> x = np.array([1., 2.])
+    >>> y = np.array([4., 5.])
+    >>> np.linalg.cross(x, y)
+    array(-3.)
+
+    Multiple vector cross-products. Note that the direction of the cross
+    product vector is defined by the `right-hand rule`.
+
+    >>> x = np.array([[1., 2., 3.], [4., 5., 6.]])
+    >>> y = np.array([[4., 5., 6.], [1., 2., 3.]])
+    >>> np.linalg.cross(x, y)
+    array([[-3.,  6., -3.],
+           [ 3., -6.,  3.]])
+    """
+    # The API standard has a single `axis`; pass it for axisa, axisb and axisc as well.
+    return _mx_nd_np.cross(a, b, axisa=axis, axisb=axis, axisc=axis, axis=axis)
+
+
+def outer(a, b):
+    r"""
+    Computes the outer product of two vectors `a` and `b`.
+
+    Notes
+    -----
+    `linalg.outer` flattens `a` and `b` before taking the outer product. It is a standard API in
+    https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-outer-x1-x2
+    instead of an official NumPy operator.
+
+    Parameters
+    ----------
+    a : ndarray
+        One-dimensional input array of size `N` . Should have a numeric data type.
+    b : ndarray
+        One-dimensional input array of size `M` . Should have a numeric data type.
+
+    Returns
+    -------
+    out : ndarray
+        A two-dimensional array containing the outer product and whose shape is `(N, M)`.
+        The returned array must have a data type determined by Type Promotion Rules.
+
+    Examples
+    --------
+    Make a (*very* coarse) grid for computing a Mandelbrot set:
+
+    >>> x = np.linalg.outer(np.ones((5,)), np.linspace(-2, 2, 5))
+    >>> x
+    array([[-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.],
+           [-2., -1.,  0.,  1.,  2.]])
+    """
+    return _mx_nd_np.tensordot(a.flatten(), b.flatten(), 0)
+
+
 def lstsq(a, b, rcond='warn'):
     r"""
     Return the least-squares solution to a linear matrix equation.
@@ -207,7 +522,8 @@ def pinv(a, rcond=1e-15, hermitian=False):
 
 
 def norm(x, ord=None, axis=None, keepdims=False):
-    r"""Matrix or vector norm.
+    r"""
+    Matrix or vector norm.
 
     This function can only support Frobenius norm for now.
     The Frobenius norm is given by [1]_:
@@ -271,7 +587,8 @@ def norm(x, ord=None, axis=None, keepdims=False):
 
 
 def svd(a):
-    r"""Singular Value Decomposition.
+    r"""
+    Singular Value Decomposition.
 
     When `a` is a 2D array, it is factorized as ``ut @ np.diag(s) @ v``,
     where `ut` and `v` are 2D orthonormal arrays and `s` is a 1D
@@ -772,7 +1089,8 @@ def tensorsolve(a, b, axes=None):
 
 
 def eigvals(a):
-    r"""Compute the eigenvalues of a general matrix.
+    r"""
+    Compute the eigenvalues of a general matrix.
 
     Main difference between `eigvals` and `eig`: the eigenvectors aren't
     returned.
@@ -840,7 +1158,8 @@ def eigvals(a):
 
 
 def eigvalsh(a, UPLO='L'):
-    r"""Compute the eigenvalues real symmetric matrix.
+    r"""
+    Compute the eigenvalues real symmetric matrix.
 
     Main difference from eigh: the eigenvectors are not computed.
 
@@ -964,7 +1283,8 @@ def eig(a):
 
 
 def eigh(a, UPLO='L'):
-    r"""Return the eigenvalues and eigenvectors real symmetric matrix.
+    r"""
+    Return the eigenvalues and eigenvectors real symmetric matrix.
 
     Returns two objects, a 1-D array containing the eigenvalues of `a`, and
     a 2-D square array or matrix (depending on the input type) of the