diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index bc4b41497ac7..57ef546de29f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2778,7 +2778,86 @@ def convert_arange(node, **kwargs): create_const_scalar_node(name+"_start", np.array([start], dtype=dtype), kwargs), create_const_scalar_node(name+"_stop", np.array([stop], dtype=dtype), kwargs), create_const_scalar_node(name+"_step", np.array([step], dtype=dtype), kwargs), - make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name]) + make_node("Range", [name+"_start", name+"_stop", name+"_step"], [name], name=name) ] return nodes + + +@mx_op.register('repeat') +def convert_repeat(node, **kwargs): + """Map MXNet's repeat operator attributes to onnx's Tile operator. + """ + from onnx.helper import make_node + from onnx import TensorProto + name, input_nodes, attrs = get_inputs(node, kwargs) + + opset_version = kwargs['opset_version'] + if opset_version < 11: + raise AttributeError('ONNX opset 11 or greater is required to export this operator') + + repeats = int(attrs.get('repeats', 1)) + axis = attrs.get('axis', 'None') + + if repeats <= 0: + raise NotImplementedError('repeat operator does not support parameter repeats==0') + + nodes = [] + if axis == 'None': + nodes += [ + create_tensor([repeats], name+'_rep', kwargs['initializer']), + create_tensor([1, repeats], name+'_repeats', kwargs['initializer']), + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('ReduceProd', [name+'_shape'], [name+'_size']), + make_node('Reshape', [input_nodes[0], name+'_size'], [name+'_flat']), + make_node('Unsqueeze', [name+'_flat'], [name+'_unsqueeze'], axes=[-1]), + make_node('Tile', [name+'_unsqueeze', name+'_repeats'], [name+'_tile']), + make_node('Mul', [name+'_size', name+'_rep'], [name+'_new_size']), + make_node('Reshape', [name+'_tile', name+'_new_size'], [name], name=name) + ] + else: + axis = int(axis) + repeats -= 1 + nodes += [ + create_tensor([repeats], name+'_repeats', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([], name+'_void', kwargs['initializer']), + create_tensor([axis], name+'_axis', kwargs['initializer']), + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Shape', [name+'_shape'], [name+'_dim']), + make_node('Reshape', [name+'_dim', name+'_void'], [name+'_dim_s']), + make_node('Range', [name+'_0', name+'_dim_s', name+'_1'], [name+'_range']) + ] + if axis < 0: + nodes += [ + make_node('Add', [name+'_axis', name+'_dim'], [name+'_true_axis']), + make_node('Equal', [name+'_range', name+'_true_axis'], [name+'_one_hot']) + ] + else: + nodes += [ + make_node('Equal', [name+'_range', name+'_axis'], [name+'_one_hot']) + ] + nodes += [ + make_node('Cast', [name+'_one_hot'], [name+'_one_hot_int'], to=int(TensorProto.INT64)), + make_node('Mul', [name+'_repeats', name+'_one_hot_int'], [name+'_mul']), + make_node('Add', [name+'_mul', name+'_1'], [name+'_add']), + make_node('Concat', [name+'_1', name+'_add'], [name+'_repeats_tensor'], axis=0) + ] + if axis == -1: + nodes += [ + make_node('Concat', [name+'_shape', name+'_1'], [name+'_unsqueeze_shape'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_unsqueeze_shape'], + [name+'_unsqueeze']) + ] + else: + nodes += [ + make_node('Unsqueeze', [input_nodes[0]], [name+'_unsqueeze'], axes=[axis+1]) + ] + nodes += [ + make_node('Tile', [name+'_unsqueeze', name+'_repeats_tensor'], [name+'_tile']), + make_node('Mul', [name+'_shape', name+'_add'], [name+'_new_shape']), + make_node('Reshape', [name+'_tile', name+'_new_shape'], [name], name=name) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 78000089868c..c17a03bc3276 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -353,6 +353,15 @@ def test_onnx_export_softmax(tmp_path, dtype): op_export_test('softmax_4', M4, [x, l4], tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +@pytest.mark.parametrize('axis', [None, 0, 1, 2, -1, -2, -3]) +@pytest.mark.parametrize('repeats', [2, 1, 3]) +def test_onnx_export_repeat(tmp_path, dtype, axis, repeats): + x = mx.nd.arange(0, 27, dtype=dtype).reshape((3, 3, 3)) + M = def_model('repeat', axis=axis, repeats=repeats) + op_export_test('repeat', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('params', [{'height': 7, 'width': 13}, {'height': 10, 'width': 16}, @@ -369,4 +378,3 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params): x = mx.nd.arange(0, 160).reshape((2, 2, 5, 8)) M = def_model('contrib.BilinearResize2D', **params) op_export_test('contrib_BilinearResize2D', M, [x], tmp_path) -