
[FEATURE] Fuse dequantize with convolution (#20816)
* Added the possibility to fuse dequantize with convolution

* Sum post-op fix and tests

* Review change

* Sanity fix

* Sanity fix

* Review suggestions

* Resolving conflicts
DominikaJedynak authored Feb 14, 2022
1 parent ff4c14f commit f4c4952
Showing 6 changed files with 115 additions and 59 deletions.
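
For orientation, below is a minimal sketch of the user-visible effect, modeled on the test cases added in tests/python/dnnl/subgraphs/test_conv_subgraph.py further down. The `check_quantize` helper is the existing utility from `subgraph_common.py`, not a public API, and the import is only valid when running from that test directory; the data shape is a hypothetical example. After quantization, the trailing dequantize that restores float32 output is folded into the fused `_sg_onednn_conv` node, which then carries `enable_float_output=True` instead of calibration ranges.

```python
# A minimal sketch, assuming MXNet 2.x built with oneDNN and execution from
# tests/python/dnnl/subgraphs/ so that the test helper below is importable.
import mxnet as mx
from mxnet.gluon import nn
from subgraph_common import check_quantize  # existing test helper, not a public API

class ConvAdd(nn.HybridBlock):
    """Same pattern as the new 'conv + add fusion case 3' test below."""
    def __init__(self, channels, **kwargs):
        super(ConvAdd, self).__init__(**kwargs)
        self.conv0 = nn.Conv2D(channels=channels, kernel_size=(1, 1), strides=1, use_bias=True)

    def forward(self, x):
        return x + self.conv0(x)

@mx.util.use_np
def run():
    data_shape = (4, 4, 10, 10)   # hypothetical NCHW shape; the tests use DATA_SHAPE
    net = ConvAdd(channels=data_shape[1])
    # check_quantize quantizes the block, runs it, and compares against the FP32
    # result; with this change the trailing dequantize is fused into the
    # quantized convolution instead of remaining a separate node.
    check_quantize(net, data_shape, 'auto')

run()
```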
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -298,6 +298,7 @@ List of Contributors
* [Bartosz Kuncer](https://github.com/bartekkuncer)
* [Maria Boerner](https://github.com/mariaboerner1987)
* [Zhenghui Jin](https://github.com/barry-jin)
* [Dominika Jedynak](https://github.com/DominikaJedynak)

Label Bot
---------
4 changes: 4 additions & 0 deletions src/operator/nn/dnnl/dnnl_convolution-inl.h
@@ -43,6 +43,7 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
bool with_sum;
bool with_postsum_act;
bool quantized;
bool enable_float_output;
bool dedup_sum;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
@@ -56,6 +57,9 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
.set_default(false)
.describe("Add post activation after sum");
DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable quantization");
DMLC_DECLARE_FIELD(enable_float_output)
.set_default(false)
.describe("Whether to enable float32 output");
DMLC_DECLARE_FIELD(dedup_sum).set_default(false).describe("deduplicated sum input");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
99 changes: 55 additions & 44 deletions src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -130,17 +130,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
auto& dnnl_param = full_conv_param.dnnl_param;
auto& conv_param = full_conv_param.conv_param;
auto bn_param = param_.bn_param.get();
size_t input_size = 2 + (conv_param.no_bias ? 0 : 1) + (dnnl_param.with_bn ? 4 : 0) +
(dnnl_param.with_sum ? 1 : 0) +
(dnnl_param.quantized ? 2 + (dnnl_param.with_sum ? 2 : 0) : 0);
// When dedup is on, in_data is used to calculate sum instead of in_sum
if (dnnl_param.dedup_sum) {
input_size -= 1;
if (dnnl_param.quantized) {
input_size -= 2;
}
}
CHECK_EQ(inputs.size(), input_size);

index_t idx = 0;

auto in_data = idx++;
@@ -164,11 +154,10 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
sum_max = inputs[idx++].data().dptr<float>()[0];
}
}
CHECK_EQ(input_size, idx);
CHECK_EQ(inputs.size(), idx);
bool has_bias = dnnl_param.with_bn || !conv_param.no_bias;
NDArray data = inputs[in_data];
NDArray output = dnnl_param.with_sum ? inputs[in_sum] : outputs[kOut];

// Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed.
if (dnnl_param.with_sum) {
if (!initialized_) {
@@ -183,9 +172,9 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
if (!inplace_) {
auto in_dnnl_mem = inputs[in_sum].GetDNNLData();
auto out_dnnl_mem = outputs[kOut].GetDNNLData();
if (outputs[kOut].dtype() == mshadow::kInt32) {
if (outputs[kOut].dtype() == mshadow::kInt32 || outputs[kOut].dtype() == mshadow::kFloat32) {
const auto& mem_desc = in_dnnl_mem->get_desc();
const auto this_dtype = get_dnnl_type(mshadow::kInt32);
const auto this_dtype = get_dnnl_type(outputs[kOut].dtype());
auto omd = mem_desc;
omd.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
dnnl_mem_ptr tmp_mem(
@@ -265,6 +254,9 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
post_requantize_ = true;
weight_channelwise_scale = true;
}
if (dnnl_param.enable_float_output) {
weight_channelwise_scale = true;
}
data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
DNNL_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
weight_scales_ = GetWeightScales<DType>(cached_weight_,
@@ -279,13 +271,24 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
if (dnnl_param.with_sum) {
sum_in_scale = GetQuantizeScale(inputs[in_sum].dtype(), cached_sum_min_, cached_sum_max_);
}
if (post_requantize_) {
output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? mshadow::kUint8 : mshadow::kInt8,
cached_output_min_,
cached_output_max_);
if (post_requantize_ || dnnl_param.enable_float_output) {
if (post_requantize_) {
output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? mshadow::kUint8 : mshadow::kInt8,
cached_output_min_,
cached_output_max_);
} else {
output_scale = 1.0;
}
full_conv_param.requantize_scales.resize(weight_channelwise_scale ? channel : 1);
for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) {
full_conv_param.requantize_scales[c] = output_scale / data_scale_ / weight_scales_[c];
full_conv_param.requantize_scales[c] = 1.0 / data_scale_ / weight_scales_[c];
}
if (dnnl_param.with_act) {
full_conv_param.act_param.scale = output_scale;
} else {
for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) {
full_conv_param.requantize_scales[c] *= output_scale;
}
}
} else {
Stream<cpu>* s = ctx.get_stream<cpu>();
@@ -322,14 +325,8 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
if (dnnl_param.with_sum) {
LOG(ERROR) << "oneDNN doesn't support conv + relu + sum fusion yet.";
full_conv_param.act_param.alpha *= output_scale;
} else {
// For conv+relu6 without sum, we don't need post_ops as output_scale can do the cut off.
dnnl_param.with_act = false;
}
}
if (dnnl_param.with_postsum_act) {
CHECK(full_conv_param.postsum_act_param.alg == dnnl::algorithm::eltwise_relu);
}
}
fwd_.reset(new DNNLConvForward(full_conv_param,
ctx.is_train,
@@ -389,7 +386,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
DNNLConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), new_inputs, req, {output});
}

if (dnnl_param.quantized) {
if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
*outputs[kMin].data().dptr<float>() = cached_output_min_;
*outputs[kMax].data().dptr<float>() = cached_output_max_;
}
@@ -521,10 +518,12 @@ static std::vector<std::string> SgDNNLConvListInputNames(const NodeAttrs& attrs)

static std::vector<std::string> SgDNNLConvListOutputNames(const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
if (param.full_conv_param.dnnl_param.quantized)
if (param.full_conv_param.dnnl_param.quantized &&
!param.full_conv_param.dnnl_param.enable_float_output) {
return std::vector<std::string>{"output", "output_min", "output_max"};
else
} else {
return std::vector<std::string>{"output"};
}
}

static OpStatePtr CreateSgDNNLConvState(const nnvm::NodeAttrs& attrs,
@@ -581,8 +580,10 @@ static bool SgDNNLConvInferShape(const nnvm::NodeAttrs& attrs,
}
}
out_shapes->at(0) = base_out_shapes[0];
SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
if (!param.full_conv_param.dnnl_param.enable_float_output) {
SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
}
return result;
} else {
return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
@@ -628,19 +629,24 @@ static bool SgDNNLConvInferType(const nnvm::NodeAttrs& attrs,
in_types->at(i) = base_in_types[base_idx++];
}
}
if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
if (IsOutputUInt8(param)) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);

if (param.full_conv_param.dnnl_param.enable_float_output) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
} else {
if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
if (IsOutputUInt8(param)) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
}
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
}
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
}

TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
}
return result;
} else {
return DefaultSubgraphOpType(attrs, in_types, out_types);
@@ -674,8 +680,10 @@ static bool SgDNNLConvOpStorageType(const nnvm::NodeAttrs& attrs,
}
}
out_stypes->at(0) = base_out_stypes[0];
type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
if (!param.full_conv_param.dnnl_param.enable_float_output) {
type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
}
return result;
} else {
return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, in_stypes, out_stypes);
@@ -736,7 +744,10 @@ NNVM_REGISTER_OP(_sg_onednn_conv)
.set_num_inputs(SgDNNLConvNumInputs)
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
return param.full_conv_param.dnnl_param.quantized ? 3 : 1;
return param.full_conv_param.dnnl_param.quantized &&
!param.full_conv_param.dnnl_param.enable_float_output ?
3 :
1;
})
.set_attr_parser(SgDNNLConvParamParser)
.set_attr<nnvm::FListInputNames>("FListInputNames", SgDNNLConvListInputNames)
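
A brief note on the scale arithmetic changed in the dnnl_conv.cc hunks above, as a sketch with made-up numbers (the real values come from GetQuantizeScale and the cached calibration ranges): the int32 accumulator of a quantized convolution carries a factor of data_scale * weight_scale[c], so dividing by that factor recovers the float result, and multiplying additionally by an output scale requantizes to int8/uint8. With enable_float_output the output scale is 1.0, so the per-channel scales perform the dequantize directly.

```python
# Illustrative numbers only; in the operator these come from GetQuantizeScale()
# and the cached calibration min/max values.
data_scale = 127.0 / 2.5                      # int8 input calibrated to [-2.5, 2.5]
weight_scales = [127.0 / 0.8, 127.0 / 1.6]    # per-channel int8 weight scales

# Requantize path (int8/uint8 output): scale the int32 accumulator into the
# calibrated output range, as in the post_requantize_ branch above.
output_scale = 255.0 / 6.0                    # uint8 output calibrated to [0, 6]
requantize_scales = [output_scale / data_scale / w for w in weight_scales]

# Float-output path (enable_float_output=True): output_scale is 1.0, so the same
# per-channel factors simply undo the input/weight quantization, i.e. the
# dequantize is absorbed into the convolution's output scales.
dequantize_scales = [1.0 / data_scale / w for w in weight_scales]

acc = 12345                                   # hypothetical int32 accumulator entry (channel 0)
print(acc * dequantize_scales[0])             # approximately the float32 convolution result
```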
28 changes: 14 additions & 14 deletions src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -37,17 +37,15 @@
namespace mxnet {
namespace op {
namespace {
const std::set<std::string> support_req_fusion_op = {"_contrib_quantized_elemwise_add",
"_contrib_quantized_elemwise_mul",
"_contrib_quantized_npi_add",
"_sg_onednn_conv",
"_sg_onednn_fully_connected",
"_sg_onednn_selfatt_qk",
"_sg_onednn_selfatt_valatt",
"_sg_onednn_batch_dot"};

const std::set<const Op*> no_enable_float_output = {Op::Get("_contrib_quantized_elemwise_add"),
Op::Get("_sg_onednn_conv")};
const std::set<std::string> support_req_fusion_op = {
"_contrib_quantized_elemwise_add",
"_contrib_quantized_elemwise_mul",
// "_contrib_quantized_npi_add" - to be added later on
"_sg_onednn_conv",
"_sg_onednn_fully_connected",
"_sg_onednn_selfatt_qk",
"_sg_onednn_selfatt_valatt",
"_sg_onednn_batch_dot"};
} // namespace

class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
@@ -113,8 +111,9 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
matched_list.emplace_back(&new_node);
status = SelectStatus::kRequantize;
if ((raw_node->op() == Op::Get("_sg_onednn_conv")) ||
(raw_node->op() == Op::Get("_contrib_quantized_elemwise_add"))) {
// Dequantize fusion is not yet supported for _contrib_quantized_elemwise_add,
// so for that operator the selection finishes once the requantize node is found:
if (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add")) {
status = SelectStatus::kSuccess;
}
return true;
@@ -214,7 +213,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {

// When only the quantized operator and a requantize are fused, set min/max_calib_range;
// when the quantized operator, requantize, and dequantize are fused, enable float output instead.
if ((dequantize_node != nullptr) && (no_enable_float_output.count(fuse_node->op()) == 0)) {
if ((dequantize_node != nullptr) &&
(fuse_node->op() != Op::Get("_contrib_quantized_elemwise_add"))) {
fuse_node->attrs.dict["enable_float_output"] = "True";
} else {
fuse_node->attrs.dict["min_calib_range"] =
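
Schematically, the selector change above means that for `_sg_onednn_conv` the post-quantize property no longer stops after matching the requantize node but keeps walking to an optional dequantize. The sketch below is a hypothetical illustration of that pattern; node names are abbreviated labels, not exact operator names.

```python
# Hypothetical before/after sketch of the subgraphs matched by
# SgDNNLPostQuantizeProperty for convolution after this change.
before = ["quantized _sg_onednn_conv", "requantize", "dequantize"]

# Only the requantize is fused: the calibration range is written into the conv node.
fused_requantize = ["_sg_onednn_conv(quantized=True, "
                    "min_calib_range=..., max_calib_range=...)"]

# Requantize + dequantize are fused: the conv emits float32 directly.
fused_dequantize = ["_sg_onednn_conv(quantized=True, enable_float_output=True)"]
```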
3 changes: 2 additions & 1 deletion tests/python/dnnl/subgraphs/subgraph_common.py
@@ -90,7 +90,8 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
if k.find('_quantize') != -1:
assert v['out_type'] == out_type
if k.find(quantized_op_name) != -1:
if quantized_op_name.startswith("quantized_sg_onednn_fully_connected") and 'enable_float_output' in v:
if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected")
or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enable_float_output' in v:
continue
assert 'min_calib_range' in v
assert 'max_calib_range' in v
39 changes: 39 additions & 0 deletions tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -115,6 +115,45 @@ def forward(self, x):
check_fusion(net, data_shape, attr, check_quantization=False)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('no_bias', [True, False])
@pytest.mark.parametrize('out_type', ['int8', 'auto'])
def test_pos_conv_add3(no_bias, data_shape, out_type):
# conv + add fusion case 3
class ConvAdd(nn.HybridBlock):
def __init__(self, use_bias, **kwargs):
super(ConvAdd, self).__init__(**kwargs)
self.conv0 = nn.Conv2D(channels=data_shape[1], kernel_size=(1, 1), strides=1, use_bias=use_bias)

def forward(self, x):
out = x + self.conv0(x)
return out

net = ConvAdd(use_bias=True)
check_quantize(net, data_shape, out_type)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('no_bias', [True, False])
@pytest.mark.parametrize('out_type', ['int8', 'auto'])
def test_pos_conv_add4(no_bias, data_shape, out_type):
# conv + add fusion case 4
class ConvAdd(nn.HybridBlock):
def __init__(self, use_bias, **kwargs):
super(ConvAdd, self).__init__(**kwargs)
self.conv0 = nn.Conv2D(channels=data_shape[1], kernel_size=(1, 1), strides=1, use_bias=use_bias)
self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)

def forward(self, x):
out = self.conv1(x + self.conv0(x))
return out

net = ConvAdd(use_bias=True)
check_quantize(net, data_shape, out_type)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('alg,quantize', [
