Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX support for box_decode (#19750)
Browse files Browse the repository at this point in the history
* box_decode

* fix sanity

Co-authored-by: Wei Chu <weichu@amazon.com>
  • Loading branch information
waytrue17 and Wei Chu authored Jan 15, 2021
1 parent 91e8429 commit d6abe00
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
77 changes: 77 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2946,3 +2946,80 @@ def convert_where(node, **kwargs):
make_node("Where", [name+"_bool", input_nodes[1], input_nodes[2]], [name], name=name)
]
return nodes

@mx_op.register("_contrib_box_decode")
def convert_contrib_box_decode(node, **kwargs):
"""Map MXNet's _contrib_box_decode operator attributes to onnx's operator.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

data = input_nodes[0]
anchors = input_nodes[1]
input_type = kwargs['in_type']
fmt = attrs.get('format', 'center')
std0 = float(attrs.get('std0', '1.'))
std1 = float(attrs.get('std1', '1.'))
std2 = float(attrs.get('std2', '1.'))
std3 = float(attrs.get('std3', '1.'))
clip = float(attrs.get('clip', '-1.'))

if fmt not in ['center', 'corner']:
raise NotImplementedError("format must be either corner or center.")

nodes = [
create_tensor([0], name+'_0', kwargs["initializer"]),
create_tensor([2], name+'_2', kwargs["initializer"]),
create_tensor([4], name+'_4', kwargs["initializer"]),
create_tensor([2], name+'_2f', kwargs["initializer"], dtype='float32'),
create_tensor([clip], name+'_clip', kwargs["initializer"], dtype='float32'),
create_tensor([std0, std1, std2, std3], name+'_std_1d', kwargs["initializer"], dtype='float32'),
create_tensor([1, 4], name+'_std_shape', kwargs["initializer"]),
make_node("Cast", [data], [name+'_data'], to=int(onnx.TensorProto.FLOAT)),
make_node("Cast", [anchors], [name+'_anchors'], to=int(onnx.TensorProto.FLOAT)),
make_node('Reshape', [name+'_std_1d', name+'_std_shape'], [name+'_std']),
make_node("Mul", [name+'_data', name+'_std'], [name+'_mul0_out']),
make_node('Slice', [name+'_mul0_out', name+'_0', name+'_2', name+'_2'], [name+'_data_xy']),
make_node('Slice', [name+'_mul0_out', name+'_2', name+'_4', name+'_2'], [name+'_data_wh']),
]

if fmt == 'corner':
nodes += [
make_node('Slice', [name+'_anchors', name+'_0', name+'_2', name+'_2'], [name+'_slice0_out']),
make_node('Slice', [name+'_anchors', name+'_2', name+'_4', name+'_2'], [name+'_slice1_out']),
make_node('Sub', [name+'_slice1_out', name+'_slice0_out'], [name+'_anchor_wh']),
make_node('Div', [name+'_anchor_wh', name+'_2f'], [name+'_div0_out']),
make_node("Add", [name+'_slice0_out', name+'_div0_out'], [name+'_anchor_xy']),
]
else:
nodes += [
make_node('Slice', [name+'_anchors', name+'_0', name+'_2', name+'_2'], [name+'_anchor_xy']),
make_node('Slice', [name+'_anchors', name+'_2', name+'_4', name+'_2'], [name+'_anchor_wh']),
]

nodes += [
make_node("Mul", [name+'_data_xy', name+'_anchor_wh'], [name+'_mul1_out']),
make_node("Add", [name+'_mul1_out', name+'_anchor_xy'], [name+'_add0_out']),
]

if clip > 0.:
nodes += [
make_node("Less", [name+"_data_wh", name+"_clip"], [name+"_less0_out"]),
make_node('Where', [name+'_less0_out', name+'_data_wh', name+'_clip'], [name+'_where0_out']),
make_node("Exp", [name+'_where0_out'], [name+'_exp0_out']),
]
else:
nodes += [
make_node("Exp", [name+'_data_wh'], [name+'_exp0_out']),
]

nodes += [
make_node("Mul", [name+'_exp0_out', name+'_anchor_wh'], [name+'_mul2_out']),
make_node('Div', [name+'_mul2_out', name+'_2f'], [name+'_div1_out']),
make_node('Sub', [name+'_add0_out', name+'_div1_out'], [name+'_sub0_out']),
make_node('Add', [name+'_add0_out', name+'_div1_out'], [name+'_add1_out']),
make_node('Concat', [name+'_sub0_out', name+'_add1_out'], [name+'concat0_out'], axis=2),
make_node("Cast", [name+'concat0_out'], [name], to=input_type, name=name)
]

return nodes
13 changes: 13 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,16 @@ def test_onnx_export_where(tmp_path, dtype, shape):
y = mx.nd.ones(shape, dtype=dtype)
cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32')
op_export_test('where', M, [cond, x, y], tmp_path)

@pytest.mark.parametrize('dtype', ['float16', 'float32'])
@pytest.mark.parametrize('fmt', ['corner', 'center'])
@pytest.mark.parametrize('clip', [-1., 0., .5, 5.])
def test_onnx_export_contrib_box_decode(tmp_path, dtype, fmt, clip):
# ensure data[0] < data[2] and data[1] < data[3] for corner format
mul = mx.nd.array([-1, -1, 1, 1], dtype=dtype)
data = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype) * mul
anchors = mx.nd.random.uniform(0, 1, (1, 3, 4), dtype=dtype) * mul
M1 = def_model('contrib.box_decode', format=fmt, clip=clip)
op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path)
M2 = def_model('contrib.box_decode', format=fmt, clip=clip, std0=0.3, std1=1.4, std2=0.5, std3=1.6)
op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path)

0 comments on commit d6abe00

Please sign in to comment.