Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[API STD][LINALG] Standardize sort & linalg operators (#20694)
Browse files Browse the repository at this point in the history
* [API] Standardize sort & linalg operators

* fix

* fix tests

* update tests

* fix lint
  • Loading branch information
barry-jin authored Nov 3, 2021
1 parent 75e4d1d commit 9e6dd92
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 160 deletions.
56 changes: 32 additions & 24 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,12 +1590,15 @@ def any(a, axis=None, out=None, keepdims=False):


@set_module('mxnet.ndarray.numpy')
def argsort(a, axis=-1, kind=None, order=None):
def argsort(a, axis=-1, descending=False, stable=True):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified
by the `kind` keyword. It returns an array of indices of the same shape as
`a` that index data along the given axis in sorted order.
Returns the indices that sort an array `x` along a specified axis.
Notes
-----
`argsort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.
Parameters
----------
Expand All @@ -1604,11 +1607,13 @@ def argsort(a, axis=-1, kind=None, order=None):
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.
Returns
-------
Expand Down Expand Up @@ -1659,29 +1664,34 @@ def argsort(a, axis=-1, kind=None, order=None):
>>> x[ind] # same as np.sort(x, axis=None)
array([0, 2, 2, 3])
"""
if order is not None:
raise NotImplementedError("order not supported here")

return _api_internal.argsort(a, axis, True, 'int64')
return _api_internal.argsort(a, axis, not descending, 'int64')


@set_module('mxnet.ndarray.numpy')
def sort(a, axis=-1, kind=None, order=None):
def sort(a, axis=-1, descending=False, stable=True):
"""
Return a sorted copy of an array.
Notes
-----
`sort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.
Parameters
----------
a : ndarray
Array to be sorted.
Array to sort.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.
Returns
-------
Expand All @@ -1704,9 +1714,7 @@ def sort(a, axis=-1, kind=None, order=None):
array([[1, 1],
[3, 4]])
"""
if order is not None:
raise NotImplementedError("order not supported here")
return _api_internal.sort(a, axis, True)
return _api_internal.sort(a, axis, not descending)

@set_module('mxnet.ndarray.numpy')
def dot(a, b, out=None):
Expand Down
14 changes: 12 additions & 2 deletions python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,16 @@ def svd(a):
return tuple(_api_internal.svd(a))


def cholesky(a):
def cholesky(a, upper=False):
r"""
Cholesky decomposition.
Notes
-----
`upper` param is requested by API standardization in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false
instead of parameter in official NumPy operator.
Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`,
where `L` is lower-triangular and .T is the transpose operator. `a` must be
symmetric and positive-definite. Only `L` is actually returned. Complex-valued
Expand All @@ -463,6 +469,10 @@ def cholesky(a):
----------
a : (..., M, M) ndarray
Symmetric, positive-definite input matrix.
upper : bool
If `True`, the result must be the upper-triangular Cholesky factor.
If `False`, the result must be the lower-triangular Cholesky factor.
Default: `False`.
Returns
-------
Expand Down Expand Up @@ -506,7 +516,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _api_internal.cholesky(a, True)
return _api_internal.cholesky(a, not upper)


def qr(a, mode='reduced'):
Expand Down
46 changes: 32 additions & 14 deletions python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,28 @@
__all__ += fallback_linalg.__all__


def matrix_rank(M, tol=None, hermitian=False):
@wrap_data_api_linalg_func
def matrix_rank(M, rtol=None, hermitian=False):
r"""
Return matrix rank of array using SVD method
Rank of the array is the number of singular values of the array that are
greater than `tol`.
greater than `rtol`.
Notes
-----
`matrix_rank` is an alias for `matrix_rank`. It is a standard API in
`rtol` param is requested in array-api-standard in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-matrix-rank-x-rtol-none
instead of an official NumPy operator.
instead of a parameter in official NumPy operator.
Parameters
----------
M : {(M,), (..., M, N)} ndarray
Input vector or stack of matrices.
tol : (...) ndarray, float, optional
Threshold below which SVD values are considered zero. If `tol` is
rtol : (...) ndarray, float, optional
Threshold below which SVD values are considered zero. If `rtol` is
None, and ``S`` is an array with singular values for `M`, and
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
``eps`` is the epsilon value for datatype of ``S``, then `rtol` is
set to ``S.max() * max(M.shape) * eps``.
hermitian : bool, optional
If True, `M` is assumed to be Hermitian (symmetric if real-valued),
Expand All @@ -73,7 +74,7 @@ def matrix_rank(M, tol=None, hermitian=False):
>>> np.linalg.matrix_rank(np.zeros((4,)))
0
"""
return _mx_nd_np.linalg.matrix_rank(M, tol, hermitian)
return _mx_nd_np.linalg.matrix_rank(M, rtol, hermitian)


def matrix_transpose(a):
Expand Down Expand Up @@ -502,22 +503,29 @@ def lstsq(a, b, rcond='warn'):
return _mx_nd_np.linalg.lstsq(a, b, rcond)


def pinv(a, rcond=1e-15, hermitian=False):
@wrap_data_api_linalg_func
def pinv(a, rtol=None, hermitian=False):
r"""
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
Calculate the generalized inverse of a matrix using its
singular-value decomposition (SVD) and including all
*large* singular values.
Notes
-----
`rtol` param is requested in array-api-standard in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-pinv-x-rtol-none
instead of a parameter in official NumPy operator.
Parameters
----------
a : (..., M, N) ndarray
Matrix or stack of matrices to be pseudo-inverted.
rcond : (...) {float or ndarray of float}, optional
rtol : (...) {float or ndarray of float}, optional
Cutoff for small singular values.
Singular values less than or equal to
``rcond * largest_singular_value`` are set to zero.
``rtol * largest_singular_value`` are set to zero.
Broadcasts against the stack of matrices.
hermitian : bool, optional
If True, `a` is assumed to be Hermitian (symmetric if real-valued),
Expand Down Expand Up @@ -567,7 +575,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
>>> (pinv_a - np.dot(pinv_a, np.dot(a, pinv_a))).sum()
array(0.)
"""
return _mx_nd_np.linalg.pinv(a, rcond, hermitian)
return _mx_nd_np.linalg.pinv(a, rtol, hermitian)


def norm(x, ord=None, axis=None, keepdims=False):
Expand Down Expand Up @@ -732,10 +740,16 @@ def svdvals(a):
return s


def cholesky(a):
def cholesky(a, upper=False):
r"""
Cholesky decomposition.
Notes
-----
`upper` param is requested by API standardization in
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-cholesky-x-upper-false
instead of parameter in official NumPy operator.
Return the Cholesky decomposition, `L * L.T`, of the square matrix `a`,
where `L` is lower-triangular and .T is the transpose operator. `a` must be
symmetric and positive-definite. Only `L` is actually returned. Complex-valued
Expand All @@ -745,6 +759,10 @@ def cholesky(a):
----------
a : (..., M, M) ndarray
Symmetric, positive-definite input matrix.
upper : bool
If `True`, the result must be the upper-triangular Cholesky factor.
If `False`, the result must be the lower-triangular Cholesky factor.
Default: `False`.
Returns
-------
Expand Down Expand Up @@ -788,7 +806,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _mx_nd_np.linalg.cholesky(a)
return _mx_nd_np.linalg.cholesky(a, upper)


def qr(a, mode='reduced'):
Expand Down
66 changes: 42 additions & 24 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@
from ..runtime import Features
from ..context import Context
from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\
is_np_default_dtype, wrap_data_api_statical_func
is_np_default_dtype, wrap_data_api_statical_func,\
wrap_sort_functions
from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
Expand Down Expand Up @@ -1973,13 +1974,13 @@ def pick(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute pick')

def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
def sort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`sort`.
The arguments are the same as for :py:func:`sort`, with
this array as data.
"""
raise sort(self, axis=axis, kind=kind, order=order)
return sort(self, axis=axis, descending=descending, stable=stable)

def topk(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`topk`.
Expand All @@ -1989,13 +1990,13 @@ def topk(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute topk')

def argsort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
def argsort(self, axis=-1, descending=False, stable=True): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`argsort`.
The arguments are the same as for :py:func:`argsort`, with
this array as data.
"""
return argsort(self, axis=axis, kind=kind, order=order)
return argsort(self, axis=axis, descending=descending, stable=stable)

def argmax_channel(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`argmax_channel`.
Expand Down Expand Up @@ -5896,12 +5897,16 @@ def arctanh(x, out=None, **kwargs):


@set_module('mxnet.numpy')
def argsort(a, axis=-1, kind=None, order=None):
@wrap_sort_functions
def argsort(a, axis=-1, descending=False, stable=True):
"""
Returns the indices that would sort an array.
Perform an indirect sort along the given axis using the algorithm specified
by the `kind` keyword. It returns an array of indices of the same shape as
`a` that index data along the given axis in sorted order.
Returns the indices that sort an array `x` along a specified axis.
Notes
-----
`argsort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#argsort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.
Parameters
----------
Expand All @@ -5910,11 +5915,13 @@ def argsort(a, axis=-1, kind=None, order=None):
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.
Returns
-------
Expand Down Expand Up @@ -5965,26 +5972,37 @@ def argsort(a, axis=-1, kind=None, order=None):
>>> x[ind] # same as np.sort(x, axis=None)
array([0, 2, 2, 3])
"""
return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order)
if stable:
warnings.warn("Currently, MXNet only support quicksort in backend, which is not stable")
return _mx_nd_np.argsort(a, axis=axis, descending=descending)


@set_module('mxnet.numpy')
def sort(a, axis=-1, kind=None, order=None):
@wrap_sort_functions
def sort(a, axis=-1, descending=False, stable=True):
"""
Return a sorted copy of an array.
Notes
-----
`sort` is a standard API in
https://data-apis.org/array-api/latest/API_specification/sorting_functions.html#sort-x-axis-1-descending-false-stable-true
instead of an official NumPy operator.
Parameters
----------
a : ndarray
Array to be sorted.
Array to sort.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
descending : bool, optional
sort order. If `True`, the returned indices sort x in descending order (by value).
If `False`, the returned indices sort x in ascending order (by value).Default: False.
stable : bool, optional
sort stability. If `True`, the returned indices must maintain the relative order
of x values which compare as equal. If `False`, the returned indices may or may not
maintain the relative order of x values which compare as equal. Default: True.
Returns
-------
Expand All @@ -6007,7 +6025,7 @@ def sort(a, axis=-1, kind=None, order=None):
array([[1, 1],
[3, 4]])
"""
return _mx_nd_np.sort(a, axis=axis, kind=kind, order=order)
return _mx_nd_np.sort(a, axis=axis, descending=descending)


@set_module('mxnet.numpy')
Expand Down
Loading

0 comments on commit 9e6dd92

Please sign in to comment.