diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 0c7ad472c840..45699f714ed4 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -27,6 +27,7 @@ from .utils import * # pylint: disable=wildcard-import from .function_base import * # pylint: disable=wildcard-import from .stride_tricks import * # pylint: disable=wildcard-import +from .set_functions import * # pylint: disable=wildcard-import from .io import * # pylint: disable=wildcard-import from .arrayprint import * # pylint: disable=wildcard-import diff --git a/python/mxnet/numpy/set_functions.py b/python/mxnet/numpy/set_functions.py new file mode 100644 index 000000000000..20e7980e7e64 --- /dev/null +++ b/python/mxnet/numpy/set_functions.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Standard Array API for creating and operating on sets.""" + +from collections import namedtuple + +from ..ndarray import numpy as _mx_nd_np + + +__all__ = ['unique_all', 'unique_inverse', 'unique_values'] + + +def unique_all(x): + """ + Returns the unique elements of an input array `x` + + Notes + ----- + `unique_all` is a standard API in + https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-all-x + instead of an official NumPy operator. + + Parameters + ---------- + x : ndarray + Input array. This will be flattened if it is not already 1-D. + + Returns + ------- + out : Tuple[ndarray, ndarray, ndarray, ndarray] + a namedtuple (values, indices, inverse_indices, counts): + values : ndarray + The sorted unique values. + indices : ndarray, optional + The indices of the first occurrences of the unique values in the + original array. + inverse_indices : ndarray + The indices to reconstruct the original array from the + unique array. + counts : ndarray + The number of times each of the unique values comes up in the + original array. + """ + UniqueAll = namedtuple('UniqueAll', ['values', 'indices', 'inverse_indices', 'counts']) + return UniqueAll(*_mx_nd_np.unique(x, True, True, True)) + + +def unique_inverse(x): + """ + Returns the unique elements of an input array `x` and the indices + from the set of unique elements that reconstruct `x`. + + Notes + ----- + `unique_inverse` is a standard API in + https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-inverse-x + instead of an official NumPy operator. + + Parameters + ---------- + x : ndarray + Input array. This will be flattened if it is not already 1-D. + + Returns + ------- + out : Tuple[ndarray, ndarray] + a namedtuple (values, inverse_indices): + values : ndarray + The sorted unique values. + inverse_indices : ndarray + The indices to reconstruct the original array from the + unique array. + """ + UniqueInverse = namedtuple('UniqueInverse', ['values', 'inverse_indices']) + return UniqueInverse(*_mx_nd_np.unique(x, False, True, False)) + + +def unique_values(x): + """ + Returns the unique elements of an input array `x`. + + Notes + ----- + `unique_values` is a standard API in + https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-values-x + instead of an official NumPy operator. + + Parameters + ---------- + x : ndarray + Input array. This will be flattened if it is not already 1-D. + + Returns + ------- + out : ndarray + The sorted unique values. + """ + return _mx_nd_np.unique(x, False, False, False) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 8bd64c1951cc..1bf8cda49720 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -8079,6 +8079,130 @@ def forward(self, a): assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5) +@use_np +@pytest.mark.parametrize('shape,index,inverse,counts', [ + ((), True, True, True), + ((1, ), True, True, True), + ((5, ), True, True, True), + ((5, ), True, True, True), + ((5, 4), True, True, True), + ((5, 0, 4), True, True, True), + ((0, 0, 0), True, True, True), + ((5, 3, 4), True, True, True), +]) +@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64']) +@pytest.mark.parametrize('hybridize', [False, True]) +def test_np_unique_all(shape, index, inverse, counts, dtype, hybridize): + class TestUniqueAll(HybridBlock): + def __init__(self): + super(TestUniqueAll, self).__init__() + + def forward(self, a): + return np.unique_all(a) + + test_unique = TestUniqueAll() + if hybridize: + test_unique.hybridize() + x = onp.random.uniform(-8.0, 8.0, size=shape) + x = np.array(x, dtype=dtype) + np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts) + mx_out = test_unique(x) + for i in range(len(mx_out)): + assert mx_out[i].shape == np_out[i].shape + assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5) + + # Test imperative once again + mx_out = np.unique_all(x) + np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts) + assert mx_out.values.shape == np_out[0].shape + assert_almost_equal(mx_out.values.asnumpy(), np_out[0], rtol=1e-3, atol=1e-5) + assert mx_out.indices.shape == np_out[1].shape + assert_almost_equal(mx_out.indices.asnumpy(), np_out[1], rtol=1e-3, atol=1e-5) + assert mx_out.inverse_indices.shape == np_out[2].shape + assert_almost_equal(mx_out.inverse_indices.asnumpy(), np_out[2], rtol=1e-3, atol=1e-5) + assert mx_out.counts.shape == np_out[3].shape + assert_almost_equal(mx_out.counts.asnumpy(), np_out[3], rtol=1e-3, atol=1e-5) + + +@use_np +@pytest.mark.parametrize('shape,index,inverse,counts', [ + ((), False, True, False), + ((1, ), False, True, False), + ((5, ), False, True, False), + ((5, ), False, True, False), + ((5, 4), False, True, False), + ((5, 0, 4), False, True, False), + ((0, 0, 0), False, True, False), + ((5, 3, 4), False, True, False), +]) +@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64']) +@pytest.mark.parametrize('hybridize', [False, True]) +def test_np_unique_inverse(shape, index, inverse, counts, dtype, hybridize): + class TestUniqueInverse(HybridBlock): + def __init__(self): + super(TestUniqueInverse, self).__init__() + + def forward(self, a): + return np.unique_inverse(a) + + test_unique = TestUniqueInverse() + if hybridize: + test_unique.hybridize() + x = onp.random.uniform(-8.0, 8.0, size=shape) + x = np.array(x, dtype=dtype) + np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts) + mx_out = test_unique(x) + for i in range(len(mx_out)): + assert mx_out[i].shape == np_out[i].shape + assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5) + + # Test imperative once again + mx_out = np.unique_inverse(x) + np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts) + assert mx_out.values.shape == np_out[0].shape + assert_almost_equal(mx_out.values.asnumpy(), np_out[0], rtol=1e-3, atol=1e-5) + assert mx_out.inverse_indices.shape == np_out[1].shape + assert_almost_equal(mx_out.inverse_indices.asnumpy(), np_out[1], rtol=1e-3, atol=1e-5) + + +@use_np +@pytest.mark.parametrize('shape,index,inverse,counts', [ + ((), False, False, False), + ((1, ), False, False, False), + ((5, ), False, False, False), + ((5, ), False, False, False), + ((5, 4), False, False, False), + ((5, 0, 4), False, False, False), + ((0, 0, 0), False, False, False), + ((5, 3, 4), False, False, False), +]) +@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64']) +@pytest.mark.parametrize('hybridize', [False, True]) +def test_np_unique_values(shape, index, inverse, counts, dtype, hybridize): + class TestUniqueValues(HybridBlock): + def __init__(self): + super(TestUniqueValues, self).__init__() + + def forward(self, a): + return np.unique_values(a) + + test_unique = TestUniqueValues() + if hybridize: + test_unique.hybridize() + x = onp.random.uniform(-8.0, 8.0, size=shape) + x = np.array(x, dtype=dtype) + np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts) + mx_out = test_unique(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + # Test imperative once again + mx_out = np.unique_values(x) + np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + @use_np def test_np_take(): configs = [