Skip to content

Commit

Permalink
extend reshape op to allow reverse shape inference (apache#11956)
Browse files Browse the repository at this point in the history
  • Loading branch information
szha authored and aaronmarkham committed Aug 6, 2018
1 parent 0328f5a commit 83c4330
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
<< arr->shape();
TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
Expand Down
36 changes: 30 additions & 6 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
if (d1 == -1) d1 = d0 / d2;
if (d2 == -1) d2 = d0 / d1;
CHECK_EQ(d1 * d2, static_cast<IType>(d0)) <<
CHECK(d1 * d2 == static_cast<IType>(d0) || static_cast<IType>(d0) == IType(0)) <<
"Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
tmp.push_back(d1);
tmp.push_back(d2);
Expand Down Expand Up @@ -151,13 +151,36 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
return oshape;
}

/*!
 * \brief Back-fill a single unknown (0-valued) dimension of \p in from the
 *        fully-known output shape \p out of a reshape, in place.
 * \param in  input shape; if it contains exactly one 0 dim, that dim is
 *            overwritten with the inferred extent.
 * \param out output shape used as the source of the total element count.
 * \return true if \p in is (now) fully determined; false if inference is
 *         impossible (unknown \p out, more than one 0 dim, or a non-divisible
 *         element count).
 */
inline bool ReverseReshapeInferShape(TShape *in, const TShape& out) {
  if (in->Size() && out.Size()) {
    // Both shapes already fully known: nothing to infer.
    return true;
  } else if (!out.Size()) {
    // Output itself is undetermined: no information to propagate backwards.
    return false;
  } else {
    int zero_axis = -1;
    // 64-bit accumulator: the product of the known dims of a large tensor
    // can overflow a 32-bit int.
    int64_t non_zero_prod = 1;
    for (index_t i = 0; i < in->ndim(); i++) {
      if ((*in)[i] == 0) {
        if (zero_axis != -1)
          return false;  // more than 1 zero found: under-determined.
        else
          zero_axis = static_cast<int>(i);
      } else {
        non_zero_prod *= (*in)[i];
      }
    }
    // in->Size() == 0 here, so at least one dim was 0 and zero_axis is set.
    // The total element count must distribute evenly over the unknown axis;
    // otherwise silently truncating the division would yield a shape whose
    // Size() no longer matches out.Size().
    if (static_cast<int64_t>(out.Size()) % non_zero_prod != 0)
      return false;
    (*in)[zero_axis] = static_cast<int64_t>(out.Size()) / non_zero_prod;
    return true;
  }
}

inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const ReshapeParam& param_ = nnvm::get<ReshapeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
CHECK_EQ(out_attrs->size(), 1U);
const TShape &dshape = (*in_attrs)[0];
TShape &dshape = (*in_attrs)[0];
if (dshape.ndim() == 0) return false;
TShape oshape;
if (param_.shape.ndim() != 0) {
Expand All @@ -182,14 +205,15 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
oshape[inf_idx] = dshape.Size() / oshape.Size();
}
} else {
return (*out_attrs)[0].ndim();
return (*out_attrs)[0].ndim() && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
}
ReverseReshapeInferShape(&dshape, oshape);
CHECK_EQ(oshape.Size(), dshape.Size())
<< "Target shape size is different to source. "
<< "Target: " << oshape
<< "\nSource: " << dshape;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
return ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
}

inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
Expand Down
35 changes: 26 additions & 9 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,11 +1943,11 @@ def test_bxor(a, b):
test_bmul(a, b)
test_bdiv(a, b)
'''
Flaky Test Disabled due to master build failure:
http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
Flaky Test Disabled due to master build failure:
http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
Github Issue: https://github.com/apache/incubator-mxnet/issues/11838
test_bmod(a, b)
test_bmod(a, b)
'''
test_bmod_int(a, b)
test_bpow(a, b)
Expand Down Expand Up @@ -2065,6 +2065,23 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
assert np.square(exe.grad_dict['data'].asnumpy() - grad_npy.reshape(src_shape)).mean() < 1E-7, \
'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s'\
%(str(src_shape), str(shape_args), str(reverse), str(dst_shape))

for i in range(len(src_shape)):
holdout_src_shape = list(src_shape)
holdout_src_shape[i] = 0
holdout_src_shape = tuple(holdout_src_shape)
net = mx.sym.Variable('data')
net = mx.sym.elemwise_add(net.reshape(shape_args, reverse=reverse), mx.sym.ones(shape=dst_shape))
input_shape, output_shape, __ = net.infer_shape(data=holdout_src_shape)
assert output_shape[0] == dst_shape, \
'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
str(dst_shape), str(output_shape[0]))
assert input_shape[0] == src_shape, \
'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
str(dst_shape), str(output_shape[0]))

# Test new api (Using shape)
test_cases = [
[(2, 3, 5, 5), (0, -1), False, (2, 75)],
Expand Down Expand Up @@ -6615,7 +6632,7 @@ def test_diag():
w = np.random.randint(2,9)
a_np = np.random.random((h, w)).astype(np.float32)
a = mx.nd.array(a_np).astype('float32')

# k == 0
r = mx.nd.diag(a)
assert_almost_equal(r.asnumpy(), np.diag(a_np))
Expand Down Expand Up @@ -6658,7 +6675,7 @@ def test_diag():
d = np.random.randint(2,9)
a_np = np.random.random((d))
a = mx.nd.array(a_np)

# k is random
k = np.random.randint(-d,d)
r = mx.nd.diag(a, k=k)
Expand Down Expand Up @@ -6725,7 +6742,7 @@ def test_invalid_block_size():
invalid_shape_inp = (n , c, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.depth_to_space, data, block)

test_invalid_depth_dim()
test_invalid_space_dim()
test_invalid_block_size()
Expand Down Expand Up @@ -6771,12 +6788,12 @@ def test_invalid_block_size():
invalid_shape_inp = (n, c, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)

def test_invalid_depth_dim():
invalid_shape_inp = (n, 0, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)

test_invalid_space_dim()
test_invalid_block_size()
test_invalid_depth_dim()
Expand Down

0 comments on commit 83c4330

Please sign in to comment.