fix batchnorm problem with sparse matrices when fix_gamma=True (#11656)
haojin2 authored and marcoabreu committed Jul 13, 2018
1 parent d611037 commit 5b4d528
Showing 6 changed files with 111 additions and 23 deletions.
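
In user-facing terms, the change makes storage-type inference for ``BatchNorm`` reject ``fix_gamma=True`` when any input is sparse, instead of computing silently on unsupported storage. A minimal Python sketch of the expected behavior, mirroring the new test_batchnorm_fallback test further down (the imperative call style and shapes here are illustrative assumptions, not part of the diff):

import numpy as np
import mxnet as mx
from mxnet.base import MXNetError

shape = (2, 3)
# Row-sparse versions of all five BatchNorm inputs.
data = mx.nd.array(np.random.normal(size=shape)).tostype('row_sparse')
gamma = mx.nd.ones((shape[1],)).tostype('row_sparse')
beta = mx.nd.ones((shape[1],)).tostype('row_sparse')
moving_mean = mx.nd.zeros((shape[1],)).tostype('row_sparse')
moving_var = mx.nd.ones((shape[1],)).tostype('row_sparse')

# fix_gamma=True combined with sparse ndarrays is rejected during
# storage-type inference (tracked at #11647), so the call raises.
try:
    mx.nd.BatchNorm(data, gamma, beta, moving_mean, moving_var, fix_gamma=True)
except MXNetError:
    print("fix_gamma=True with sparse inputs raises, as expected")
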
3 changes: 3 additions & 0 deletions src/operator/batch_norm_v1.cc
@@ -89,6 +89,9 @@ the output. It is often used during inference.
Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
then set ``gamma`` to 1 and its gradient to 0.
There's no sparse support for this operator, and it will exhibit problematic behavior if used with
sparse tensors.
)code" ADD_FILELINE)
.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
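
As a side note on the ``fix_gamma`` semantics described in the docstring above, here is a small dense sketch (not part of this diff; it uses the current ``BatchNorm`` operator, which shares the ``fix_gamma`` behavior, and arbitrary values): with ``fix_gamma=True`` the supplied ``gamma`` is treated as 1, so the result should match an explicit unit ``gamma``.

import mxnet as mx

x = mx.nd.random.normal(shape=(4, 3))
gamma = mx.nd.array([2.0, 2.0, 2.0])  # deliberately not 1
beta = mx.nd.zeros((3,))
moving_mean = mx.nd.zeros((3,))
moving_var = mx.nd.ones((3,))

# With fix_gamma=True the supplied gamma is ignored and treated as 1,
# so the output should match a fix_gamma=False run with unit gamma.
fixed = mx.nd.BatchNorm(x, gamma, beta, moving_mean, moving_var, fix_gamma=True)
unit = mx.nd.BatchNorm(x, mx.nd.ones((3,)), beta, moving_mean, moving_var, fix_gamma=False)
print(mx.nd.abs(fixed - unit).max())  # expected to be (near) zero
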
46 changes: 30 additions & 16 deletions src/operator/nn/batch_norm.cc
@@ -321,6 +321,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
using namespace mshadow;
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
CHECK_EQ(out_shape->size(), 3U);
const TShape &dshape = in_shape->at(batchnorm::kData);

const size_t channelAxis = static_cast<size_t>(param.axis < 0
@@ -444,27 +445,37 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
}
FallBackCompute(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
#endif

static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
DispatchMode *dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 5);
CHECK_EQ(out_attrs->size(), 3);
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
}
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);

static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
DispatchMode *dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
}
bool dispatched = false;
#if MXNET_USE_MKLDNN == 1
if (!dispatched) {
dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
}
#else
for (int& v : *in_attrs)
if (v == - 1) v = kDefaultStorage;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
#endif
if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
}
return dispatched;
}

std::vector<nnvm::NodeEntry> BatchNormGrad(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
@@ -552,6 +563,11 @@ axis to be the last item in the input shape.
Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
then set ``gamma`` to 1 and its gradient to 0.
Note::
When ``fix_gamma`` is set to True, there is no sparse support. If ``fix_gamma`` is set to False,
sparse tensors will fall back to the dense implementation.
)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(3)
@@ -574,9 +590,7 @@ then set ``gamma`` to 1 and its gradient to 0.
})
.set_attr<nnvm::FInferShape>("FInferShape", BatchNormShape)
.set_attr<nnvm::FInferType>("FInferType", BatchNormType)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
#endif
.set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
@@ -607,8 +621,8 @@ then set ``gamma`` to 1 and its gradient to 0.
NNVM_REGISTER_OP(_backward_BatchNorm)
.set_num_outputs(3)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", backward_BatchNormStorageType)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
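
To illustrate the fallback path described in the note above, a short sketch mirroring the updated tests (the output storage type printed at the end is an assumption based on the dense dispatch):

import numpy as np
import mxnet as mx

data = mx.nd.array(np.random.normal(size=(2, 3, 2, 2))).tostype('row_sparse')
gamma = mx.nd.ones((3,)).tostype('row_sparse')
beta = mx.nd.ones((3,)).tostype('row_sparse')
moving_mean = mx.nd.zeros((3,)).tostype('row_sparse')
moving_var = mx.nd.ones((3,)).tostype('row_sparse')

# With fix_gamma=False, storage-type inference dispatches to the dense
# FCompute path (a fallback), so this call runs instead of raising.
out = mx.nd.BatchNorm(data, gamma, beta, moving_mean, moving_var,
                      fix_gamma=False, use_global_stats=True)
print(out.stype)  # expected: 'default'
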
4 changes: 2 additions & 2 deletions src/operator/nn/fully_connected.cc
@@ -213,8 +213,8 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
// TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
// It seems there is a bug.
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
2 changes: 1 addition & 1 deletion tests/python/mkl/test_mkldnn.py
@@ -233,7 +233,7 @@ def check_batchnorm_training(stype):
mx.nd.array(beta).tostype(stype)]
mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]

test = mx.symbol.BatchNorm(data, fix_gamma=True)
test = mx.symbol.BatchNorm(data, fix_gamma=False)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)

stypes = ['row_sparse', 'default']
4 changes: 1 addition & 3 deletions tests/python/unittest/test_operator.py
@@ -1552,9 +1552,7 @@ def check_batchnorm_training(stype):
test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)

stypes = ['default']
for stype in stypes:
check_batchnorm_training(stype)
check_batchnorm_training('default')


@with_seed()
75 changes: 74 additions & 1 deletion tests/python/unittest/test_sparse_operator.py
@@ -16,7 +16,8 @@
# under the License.

from mxnet.test_utils import *
from common import setup_module, with_seed, teardown
from mxnet.base import MXNetError
from common import setup_module, with_seed, teardown, assertRaises
import random
import warnings

@@ -2098,6 +2099,78 @@ def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, for
lambda l, r: l + r,
rhs_is_scalar=True, verbose=False, density=0.5)


@with_seed()
def test_batchnorm_fallback():
# same test as test_operator.test_batchnorm_training, but tests fallback logic of batchnorm
stype = 'row_sparse'
for shape in [(2, 3), (2, 3, 2, 2)]:
data_tmp = np.random.normal(-0.1, 0.1, size=shape)
s = shape[1],
gamma = np.ones(s)
beta = np.ones(s)
gamma[1] = 3
beta[0] = 3

rolling_mean = np.random.uniform(size=s)
rolling_std = np.random.uniform(size=s)

data = mx.symbol.Variable('data', stype=stype)
in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
mx.nd.array(beta).tostype(stype)]
mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]

test = mx.symbol.BatchNorm(data, fix_gamma=True)
assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)

test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)

test = mx.symbol.BatchNorm(data, fix_gamma=False)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)

test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)

# Test varying channel axis
dim = len(shape)
for chaxis in range(-dim, dim):
chaxis_true = chaxis
if chaxis < 0:
chaxis_true = dim + chaxis

shapex = shape

channel_count = shapex[chaxis_true]
data_tmp = np.random.normal(-0.1, 0.1, size=shapex)

gamma = np.ones(channel_count)
beta = np.ones(channel_count)
if channel_count > 1:
gamma[1] = 3
beta[0] = 3

in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
mx.nd.array(beta).tostype(stype)]

xrolling_mean = np.random.uniform(size=channel_count)
xrolling_std = np.random.uniform(size=channel_count)
xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
mx.nd.array(xrolling_std).tostype(stype)]

test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)

test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)

test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)

test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)


@with_seed()
def test_mkldnn_sparse():
# This test is trying to create a race condition described in
