Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[API NEW][SET FUNC] Add set functions (#20693)
Browse files Browse the repository at this point in the history
* [API] Add set functions

* update tests

* fix lint
  • Loading branch information
barry-jin authored Nov 3, 2021
1 parent 9e6dd92 commit 630a144
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/mxnet/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .utils import * # pylint: disable=wildcard-import
from .function_base import * # pylint: disable=wildcard-import
from .stride_tricks import * # pylint: disable=wildcard-import
from .set_functions import * # pylint: disable=wildcard-import
from .io import * # pylint: disable=wildcard-import
from .arrayprint import * # pylint: disable=wildcard-import

Expand Down
113 changes: 113 additions & 0 deletions python/mxnet/numpy/set_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Standard Array API for creating and operating on sets."""

from collections import namedtuple

from ..ndarray import numpy as _mx_nd_np


__all__ = ['unique_all', 'unique_inverse', 'unique_values']


def unique_all(x):
"""
Returns the unique elements of an input array `x`
Notes
-----
`unique_all` is a standard API in
https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-all-x
instead of an official NumPy operator.
Parameters
----------
x : ndarray
Input array. This will be flattened if it is not already 1-D.
Returns
-------
out : Tuple[ndarray, ndarray, ndarray, ndarray]
a namedtuple (values, indices, inverse_indices, counts):
values : ndarray
The sorted unique values.
indices : ndarray, optional
The indices of the first occurrences of the unique values in the
original array.
inverse_indices : ndarray
The indices to reconstruct the original array from the
unique array.
counts : ndarray
The number of times each of the unique values comes up in the
original array.
"""
UniqueAll = namedtuple('UniqueAll', ['values', 'indices', 'inverse_indices', 'counts'])
return UniqueAll(*_mx_nd_np.unique(x, True, True, True))


def unique_inverse(x):
"""
Returns the unique elements of an input array `x` and the indices
from the set of unique elements that reconstruct `x`.
Notes
-----
`unique_inverse` is a standard API in
https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-inverse-x
instead of an official NumPy operator.
Parameters
----------
x : ndarray
Input array. This will be flattened if it is not already 1-D.
Returns
-------
out : Tuple[ndarray, ndarray]
a namedtuple (values, inverse_indices):
values : ndarray
The sorted unique values.
inverse_indices : ndarray
The indices to reconstruct the original array from the
unique array.
"""
UniqueInverse = namedtuple('UniqueInverse', ['values', 'inverse_indices'])
return UniqueInverse(*_mx_nd_np.unique(x, False, True, False))


def unique_values(x):
"""
Returns the unique elements of an input array `x`.
Notes
-----
`unique_values` is a standard API in
https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-values-x
instead of an official NumPy operator.
Parameters
----------
x : ndarray
Input array. This will be flattened if it is not already 1-D.
Returns
-------
out : ndarray
The sorted unique values.
"""
return _mx_nd_np.unique(x, False, False, False)
124 changes: 124 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8079,6 +8079,130 @@ def forward(self, a):
assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)


@use_np
@pytest.mark.parametrize('shape,index,inverse,counts', [
((), True, True, True),
((1, ), True, True, True),
((5, ), True, True, True),
((5, ), True, True, True),
((5, 4), True, True, True),
((5, 0, 4), True, True, True),
((0, 0, 0), True, True, True),
((5, 3, 4), True, True, True),
])
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64'])
@pytest.mark.parametrize('hybridize', [False, True])
def test_np_unique_all(shape, index, inverse, counts, dtype, hybridize):
class TestUniqueAll(HybridBlock):
def __init__(self):
super(TestUniqueAll, self).__init__()

def forward(self, a):
return np.unique_all(a)

test_unique = TestUniqueAll()
if hybridize:
test_unique.hybridize()
x = onp.random.uniform(-8.0, 8.0, size=shape)
x = np.array(x, dtype=dtype)
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
mx_out = test_unique(x)
for i in range(len(mx_out)):
assert mx_out[i].shape == np_out[i].shape
assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)

# Test imperative once again
mx_out = np.unique_all(x)
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
assert mx_out.values.shape == np_out[0].shape
assert_almost_equal(mx_out.values.asnumpy(), np_out[0], rtol=1e-3, atol=1e-5)
assert mx_out.indices.shape == np_out[1].shape
assert_almost_equal(mx_out.indices.asnumpy(), np_out[1], rtol=1e-3, atol=1e-5)
assert mx_out.inverse_indices.shape == np_out[2].shape
assert_almost_equal(mx_out.inverse_indices.asnumpy(), np_out[2], rtol=1e-3, atol=1e-5)
assert mx_out.counts.shape == np_out[3].shape
assert_almost_equal(mx_out.counts.asnumpy(), np_out[3], rtol=1e-3, atol=1e-5)


@use_np
@pytest.mark.parametrize('shape,index,inverse,counts', [
((), False, True, False),
((1, ), False, True, False),
((5, ), False, True, False),
((5, ), False, True, False),
((5, 4), False, True, False),
((5, 0, 4), False, True, False),
((0, 0, 0), False, True, False),
((5, 3, 4), False, True, False),
])
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64'])
@pytest.mark.parametrize('hybridize', [False, True])
def test_np_unique_inverse(shape, index, inverse, counts, dtype, hybridize):
class TestUniqueInverse(HybridBlock):
def __init__(self):
super(TestUniqueInverse, self).__init__()

def forward(self, a):
return np.unique_inverse(a)

test_unique = TestUniqueInverse()
if hybridize:
test_unique.hybridize()
x = onp.random.uniform(-8.0, 8.0, size=shape)
x = np.array(x, dtype=dtype)
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
mx_out = test_unique(x)
for i in range(len(mx_out)):
assert mx_out[i].shape == np_out[i].shape
assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)

# Test imperative once again
mx_out = np.unique_inverse(x)
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
assert mx_out.values.shape == np_out[0].shape
assert_almost_equal(mx_out.values.asnumpy(), np_out[0], rtol=1e-3, atol=1e-5)
assert mx_out.inverse_indices.shape == np_out[1].shape
assert_almost_equal(mx_out.inverse_indices.asnumpy(), np_out[1], rtol=1e-3, atol=1e-5)


@use_np
@pytest.mark.parametrize('shape,index,inverse,counts', [
((), False, False, False),
((1, ), False, False, False),
((5, ), False, False, False),
((5, ), False, False, False),
((5, 4), False, False, False),
((5, 0, 4), False, False, False),
((0, 0, 0), False, False, False),
((5, 3, 4), False, False, False),
])
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64'])
@pytest.mark.parametrize('hybridize', [False, True])
def test_np_unique_values(shape, index, inverse, counts, dtype, hybridize):
class TestUniqueValues(HybridBlock):
def __init__(self):
super(TestUniqueValues, self).__init__()

def forward(self, a):
return np.unique_values(a)

test_unique = TestUniqueValues()
if hybridize:
test_unique.hybridize()
x = onp.random.uniform(-8.0, 8.0, size=shape)
x = np.array(x, dtype=dtype)
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
mx_out = test_unique(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

# Test imperative once again
mx_out = np.unique_values(x)
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@use_np
def test_np_take():
configs = [
Expand Down

0 comments on commit 630a144

Please sign in to comment.