From f4c4952a4ea58d2565160442afa3a32a401dc352 Mon Sep 17 00:00:00 2001
From: DominikaJedynak
Date: Mon, 14 Feb 2022 14:47:07 +0100
Subject: [PATCH] [FEATURE] Fuse dequantize with convolution (#20816)

* Added possibility to fuse dequantize with convolution

* Sum post-op fix and tests

* Review change

* Sanity fix

* Sanity fix

* Review suggestions

* Resolving conflicts
---
 CONTRIBUTORS.md                              |  1 +
 src/operator/nn/dnnl/dnnl_convolution-inl.h  |  4 +
 src/operator/subgraph/dnnl/dnnl_conv.cc      | 99 ++++++++++---------
 .../dnnl/dnnl_post_quantize_property.h       | 28 +++---
 .../python/dnnl/subgraphs/subgraph_common.py |  3 +-
 .../dnnl/subgraphs/test_conv_subgraph.py     | 39 ++++++++
 6 files changed, 115 insertions(+), 59 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 47b491d6b74e..ee8a3841eeec 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -298,6 +298,7 @@ List of Contributors
 * [Bartosz Kuncer](https://github.com/bartekkuncer)
 * [Maria Boerner](https://github.com/mariaboerner1987)
 * [Zhenghui Jin](https://github.com/barry-jin)
+* [Dominika Jedynak](https://github.com/DominikaJedynak)
 
 Label Bot
 ---------
diff --git a/src/operator/nn/dnnl/dnnl_convolution-inl.h b/src/operator/nn/dnnl/dnnl_convolution-inl.h
index 0c48d0d9faa8..2bac446beb20 100644
--- a/src/operator/nn/dnnl/dnnl_convolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_convolution-inl.h
@@ -43,6 +43,7 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
   bool with_sum;
   bool with_postsum_act;
   bool quantized;
+  bool enable_float_output;
   bool dedup_sum;
 
   dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
@@ -56,6 +57,9 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
         .set_default(false)
         .describe("Add post activation after sum");
     DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable quantization");
+    DMLC_DECLARE_FIELD(enable_float_output)
+        .set_default(false)
+        .describe("Whether to enable float32 output");
     DMLC_DECLARE_FIELD(dedup_sum).set_default(false).describe("deduplicated sum input");
     DMLC_DECLARE_FIELD(min_calib_range)
         .set_default(dmlc::optional<float>())
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index 27f9f8a8b1f8..ccaabdd68969 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -130,17 +130,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
   auto& dnnl_param = full_conv_param.dnnl_param;
   auto& conv_param = full_conv_param.conv_param;
   auto bn_param = param_.bn_param.get();
-  size_t input_size = 2 + (conv_param.no_bias ? 0 : 1) + (dnnl_param.with_bn ? 4 : 0) +
-                      (dnnl_param.with_sum ? 1 : 0) +
-                      (dnnl_param.quantized ? 2 + (dnnl_param.with_sum ? 2 : 0) : 0);
-  // When dedup is on, in_data is used to calculate sum instead of in_sum
-  if (dnnl_param.dedup_sum) {
-    input_size -= 1;
-    if (dnnl_param.quantized) {
-      input_size -= 2;
-    }
-  }
-  CHECK_EQ(inputs.size(), input_size);
+
   index_t idx = 0;
 
   auto in_data = idx++;
@@ -164,11 +154,10 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
       sum_max = inputs[idx++].data().dptr<float>()[0];
     }
   }
 
-  CHECK_EQ(input_size, idx);
+  CHECK_EQ(inputs.size(), idx);
   bool has_bias = dnnl_param.with_bn || !conv_param.no_bias;
   NDArray data = inputs[in_data];
   NDArray output = dnnl_param.with_sum ? inputs[in_sum] : outputs[kOut];
-  // Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed.
   if (dnnl_param.with_sum) {
     if (!initialized_) {
@@ -183,9 +172,9 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
     if (!inplace_) {
       auto in_dnnl_mem = inputs[in_sum].GetDNNLData();
       auto out_dnnl_mem = outputs[kOut].GetDNNLData();
-      if (outputs[kOut].dtype() == mshadow::kInt32) {
+      if (outputs[kOut].dtype() == mshadow::kInt32 || outputs[kOut].dtype() == mshadow::kFloat32) {
         const auto& mem_desc = in_dnnl_mem->get_desc();
-        const auto this_dtype = get_dnnl_type(mshadow::kInt32);
+        const auto this_dtype = get_dnnl_type(outputs[kOut].dtype());
         auto omd = mem_desc;
         omd.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
         dnnl_mem_ptr tmp_mem(
@@ -265,6 +254,9 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
         post_requantize_ = true;
         weight_channelwise_scale = true;
       }
+      if (dnnl_param.enable_float_output) {
+        weight_channelwise_scale = true;
+      }
       data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
       DNNL_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
         weight_scales_ = GetWeightScales(cached_weight_,
@@ -279,13 +271,24 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
       if (dnnl_param.with_sum) {
         sum_in_scale = GetQuantizeScale(inputs[in_sum].dtype(), cached_sum_min_, cached_sum_max_);
       }
-      if (post_requantize_) {
-        output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? mshadow::kUint8 : mshadow::kInt8,
-                                        cached_output_min_,
-                                        cached_output_max_);
+      if (post_requantize_ || dnnl_param.enable_float_output) {
+        if (post_requantize_) {
+          output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? mshadow::kUint8 : mshadow::kInt8,
+                                          cached_output_min_,
+                                          cached_output_max_);
+        } else {
+          output_scale = 1.0;
+        }
         full_conv_param.requantize_scales.resize(weight_channelwise_scale ? channel : 1);
         for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) {
-          full_conv_param.requantize_scales[c] = output_scale / data_scale_ / weight_scales_[c];
+          full_conv_param.requantize_scales[c] = 1.0 / data_scale_ / weight_scales_[c];
+        }
+        if (dnnl_param.with_act) {
+          full_conv_param.act_param.scale = output_scale;
+        } else {
+          for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) {
+            full_conv_param.requantize_scales[c] *= output_scale;
+          }
         }
       } else {
         Stream<cpu>* s = ctx.get_stream<cpu>();
@@ -322,14 +325,8 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
       if (dnnl_param.with_sum) {
         LOG(ERROR) << "oneDNN doesn't support conv + relu + sum fusion yet.";
         full_conv_param.act_param.alpha *= output_scale;
-      } else {
-        // For conv+relu6 without sum, we don't need post_ops as output_scale can do the cut off.
-        dnnl_param.with_act = false;
       }
     }
-    if (dnnl_param.with_postsum_act) {
-      CHECK(full_conv_param.postsum_act_param.alg == dnnl::algorithm::eltwise_relu);
-    }
   }
   fwd_.reset(new DNNLConvForward(full_conv_param,
                                  ctx.is_train,
@@ -389,7 +386,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
     DNNLConvolutionForwardFullFeature(
         full_conv_param, ctx, fwd_.get(), new_inputs, req, {output});
   }
-  if (dnnl_param.quantized) {
+  if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
     *outputs[kMin].data().dptr<float>() = cached_output_min_;
     *outputs[kMax].data().dptr<float>() = cached_output_max_;
   }
@@ -521,10 +518,12 @@ static std::vector<std::string> SgDNNLConvListInputNames(const NodeAttrs& attrs)
 
 static std::vector<std::string> SgDNNLConvListOutputNames(const NodeAttrs& attrs) {
   auto const& param = nnvm::get<SgDNNLConvParam>(attrs.parsed);
-  if (param.full_conv_param.dnnl_param.quantized)
+  if (param.full_conv_param.dnnl_param.quantized &&
+      !param.full_conv_param.dnnl_param.enable_float_output) {
     return std::vector<std::string>{"output", "output_min", "output_max"};
-  else
+  } else {
     return std::vector<std::string>{"output"};
+  }
 }
 
 static OpStatePtr CreateSgDNNLConvState(const nnvm::NodeAttrs& attrs,
@@ -581,8 +580,10 @@ static bool SgDNNLConvInferShape(const nnvm::NodeAttrs& attrs,
       }
     }
     out_shapes->at(0) = base_out_shapes[0];
-    SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
-    SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
+    if (!param.full_conv_param.dnnl_param.enable_float_output) {
+      SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
+      SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
+    }
     return result;
   } else {
     return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
@@ -628,19 +629,24 @@ static bool SgDNNLConvInferType(const nnvm::NodeAttrs& attrs,
         in_types->at(i) = base_in_types[base_idx++];
       }
     }
-    if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
-        param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
-      if (IsOutputUInt8(param)) {
-        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
+
+    if (param.full_conv_param.dnnl_param.enable_float_output) {
+      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
+    } else {
+      if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
+          param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
+        if (IsOutputUInt8(param)) {
+          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
+        } else {
+          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
+        }
       } else {
-        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
       }
-    } else {
-      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
-    }
-    TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
-    TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
+      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
+    }
 
     return result;
   } else {
     return DefaultSubgraphOpType(attrs, in_types, out_types);
@@ -674,8 +680,10 @@ static bool SgDNNLConvOpStorageType(const nnvm::NodeAttrs& attrs,
       }
     }
     out_stypes->at(0) = base_out_stypes[0];
-    type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
-    type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
+    if (!param.full_conv_param.dnnl_param.enable_float_output) {
+      type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
+      type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
+    }
     return result;
   } else {
     return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, in_stypes, out_stypes);
@@ -736,7 +744,10 @@ NNVM_REGISTER_OP(_sg_onednn_conv)
     .set_num_inputs(SgDNNLConvNumInputs)
     .set_num_outputs([](const NodeAttrs& attrs) {
       auto const& param = nnvm::get<SgDNNLConvParam>(attrs.parsed);
-      return param.full_conv_param.dnnl_param.quantized ? 3 : 1;
+      return param.full_conv_param.dnnl_param.quantized &&
+                     !param.full_conv_param.dnnl_param.enable_float_output ?
+                 3 :
+                 1;
     })
     .set_attr_parser(SgDNNLConvParamParser)
     .set_attr<nnvm::FListInputNames>("FListInputNames", SgDNNLConvListInputNames)
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 1aa52ca43ad7..456a0d10399e 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -37,17 +37,15 @@ namespace mxnet {
 namespace op {
 namespace {
 
-const std::set<std::string> support_req_fusion_op = {"_contrib_quantized_elemwise_add",
-                                                     "_contrib_quantized_elemwise_mul",
-                                                     "_contrib_quantized_npi_add",
-                                                     "_sg_onednn_conv",
-                                                     "_sg_onednn_fully_connected",
-                                                     "_sg_onednn_selfatt_qk",
-                                                     "_sg_onednn_selfatt_valatt",
-                                                     "_sg_onednn_batch_dot"};
-
-const std::set<const Op*> no_enable_float_output = {Op::Get("_contrib_quantized_elemwise_add"),
-                                                    Op::Get("_sg_onednn_conv")};
+const std::set<std::string> support_req_fusion_op = {
+    "_contrib_quantized_elemwise_add",
+    "_contrib_quantized_elemwise_mul",
+    // "_contrib_quantized_npi_add" - to be added later on
+    "_sg_onednn_conv",
+    "_sg_onednn_fully_connected",
+    "_sg_onednn_selfatt_qk",
+    "_sg_onednn_selfatt_valatt",
+    "_sg_onednn_batch_dot"};
 }  // namespace
 
 class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
@@ -113,8 +111,9 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
       if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
         matched_list.emplace_back(&new_node);
         status = SelectStatus::kRequantize;
-        if ((raw_node->op() == Op::Get("_sg_onednn_conv")) ||
-            (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add"))) {
+        // For now there is no support for dequantize fusion for contrib_quantized_elemwise_add
+        // so with this operator we finish after finding requantize node:
+        if (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add")) {
           status = SelectStatus::kSuccess;
         }
         return true;
@@ -214,7 +213,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
     // When only fused quantized operator and requantize, set min/max_cablib_range,
     // When fused quantized operator + requantize + dequantize, set dequantize flag to true.
-    if ((dequantize_node != nullptr) && (no_enable_float_output.count(fuse_node->op()) == 0)) {
+    if ((dequantize_node != nullptr) &&
+        (fuse_node->op() != Op::Get("_contrib_quantized_elemwise_add"))) {
       fuse_node->attrs.dict["enable_float_output"] = "True";
     } else {
       fuse_node->attrs.dict["min_calib_range"] =
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py
index be2adb9e2f03..1657a05fc10a 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -90,7 +90,8 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
     if k.find('_quantize') != -1:
       assert v['out_type'] == out_type
     if k.find(quantized_op_name) != -1:
-      if quantized_op_name.startswith("quantized_sg_onednn_fully_connected") and 'enable_float_output' in v:
+      if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected")
+          or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enable_float_output' in v:
         continue
       assert 'min_calib_range' in v
       assert 'max_calib_range' in v
diff --git a/tests/python/dnnl/subgraphs/test_conv_subgraph.py b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
index e7dac8f8be59..286f3371332d 100644
--- a/tests/python/dnnl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -115,6 +115,45 @@ def forward(self, x):
   check_fusion(net, data_shape, attr, check_quantization=False)
 
 
+@mx.util.use_np
+@pytest.mark.parametrize('data_shape', DATA_SHAPE)
+@pytest.mark.parametrize('no_bias', [True, False])
+@pytest.mark.parametrize('out_type', ['int8', 'auto'])
+def test_pos_conv_add3(no_bias, data_shape, out_type):
+  # conv + add fusion case 3
+  class ConvAdd(nn.HybridBlock):
+    def __init__(self, use_bias, **kwargs):
+      super(ConvAdd, self).__init__(**kwargs)
+      self.conv0 = nn.Conv2D(channels=data_shape[1], kernel_size=(1, 1), strides=1, use_bias=use_bias)
+
+    def forward(self, x):
+      out = x + self.conv0(x)
+      return out
+
+  net = ConvAdd(use_bias=True)
+  check_quantize(net, data_shape, out_type)
+
+
+@mx.util.use_np
+@pytest.mark.parametrize('data_shape', DATA_SHAPE)
+@pytest.mark.parametrize('no_bias', [True, False])
+@pytest.mark.parametrize('out_type', ['int8', 'auto'])
+def test_pos_conv_add4(no_bias, data_shape, out_type):
+  # conv + add fusion case 4
+  class ConvAdd(nn.HybridBlock):
+    def __init__(self, use_bias, **kwargs):
+      super(ConvAdd, self).__init__(**kwargs)
+      self.conv0 = nn.Conv2D(channels=data_shape[1], kernel_size=(1, 1), strides=1, use_bias=use_bias)
+      self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)
+
+    def forward(self, x):
+      out = self.conv1(x + self.conv0(x))
+      return out
+
+  net = ConvAdd(use_bias=True)
+  check_quantize(net, data_shape, out_type)
+
+
 @mx.util.use_np
 @pytest.mark.parametrize('data_shape', DATA_SHAPE)
 @pytest.mark.parametrize('alg,quantize', [