diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3576242e0d77..1778f606f5df 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1778,44 +1778,39 @@ def convert_slice_axis(node, **kwargs): return nodes -@mx_op.register("SliceChannel") +@mx_op.register('SliceChannel') def convert_slice_channel(node, **kwargs): """Map MXNet's SliceChannel operator attributes to onnx's Squeeze or Split operator based on squeeze_axis attribute and return the created node. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) - opset_version = kwargs['opset_version'] - if opset_version < 11: - raise AttributeError('ONNX opset 11 or greater is required to export this operator') + num_outputs = int(attrs.get('num_outputs')) + axis = int(attrs.get('axis', 1)) + squeeze_axis = attrs.get('squeeze_axis', 'False') - num_outputs = int(attrs.get("num_outputs")) - axis = int(attrs.get("axis", 1)) - squeeze_axis = int(attrs.get("squeeze_axis", 0) in [1, 'True']) + create_tensor([axis], name+'_axis', kwargs['initializer']) + create_tensor([axis+1], name+'axis_p1', kwargs['initializer']) - if squeeze_axis == 1 and num_outputs == 1: - node = onnx.helper.make_node( - "Squeeze", - input_nodes, - [name], - axes=[axis], - name=name, - ) - return [node] - elif squeeze_axis == 0 and num_outputs > 1: - node = onnx.helper.make_node( - "Split", - input_nodes, - [name+str(i) for i in range(num_outputs)], - axis=axis, - name=name - ) - return [node] + nodes = [] + if squeeze_axis in ['True', '1']: + nodes += [ + make_node('Split', [input_nodes[0]], [name+str(i)+'_' for i in range(num_outputs)], + axis=axis) + ] + for i in range(num_outputs): + nodes += [ + make_node('Squeeze', [name+str(i)+'_'], [name+str(i)], axes=[axis]) + ] else: - raise NotImplementedError("SliceChannel operator with num_outputs>1 and" - "squeeze_axis true is not implemented.") + nodes += [ + make_node('Split', [input_nodes[0]], [name+str(i) for i in range(num_outputs)], + axis=axis) + ] + return nodes @mx_op.register("expand_dims") def convert_expand_dims(node, **kwargs): @@ -3089,6 +3084,7 @@ def convert_contrib_box_nms(node, **kwargs): coord_start = int(attrs.get('coord_start', '2')) score_index = int(attrs.get('score_index', '1')) id_index = int(attrs.get('id_index', '-1')) + force_suppress = attrs.get('force_suppress', 'True') background_id = int(attrs.get('background_id', '-1')) in_format = attrs.get('in_format', 'corner') out_format = attrs.get('out_format', 'corner') @@ -3101,8 +3097,11 @@ def convert_contrib_box_nms(node, **kwargs): if background_id != -1: raise NotImplementedError('box_nms does not currently support background_id != -1') - if id_index != -1: - raise NotImplementedError('box_nms does not currently support id_index != -1') + if id_index != -1 or force_suppress == 'False': + logging.warning('box_nms: id_idex != -1 or/and force_suppress == False detected. ' + 'However, due to ONNX limitations, boxes of different categories will NOT ' + 'be exempted from suppression. This might lead to different behavior than ' + 'native MXNet') nodes = [ create_tensor([coord_start], name+'_cs', kwargs['initializer']), diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index f4b44e58f188..977b6dbdf5ef 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -929,6 +929,19 @@ def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group, op_export_test('convolution', M, inputs, tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32']) +@pytest.mark.parametrize('num_outputs', [1, 3, 9]) +@pytest.mark.parametrize('axis', [1, 2, -1, -2]) +@pytest.mark.parametrize('squeeze_axis', [True, False, 0, 1]) +def test_onnx_export_slice_channel(tmp_path, dtype, num_outputs, axis, squeeze_axis): + shape = (3, 9, 18) + if squeeze_axis and shape[axis] != num_outputs: + return + M = def_model('SliceChannel', num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis) + x = mx.random.uniform(0, 1, shape, dtype=dtype) + op_export_test('slice_channel', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['float32', 'float64']) @pytest.mark.parametrize('momentum', [0.9, 0.5, 0.1]) def test_onnx_export_batchnorm(tmp_path, dtype, momentum):