This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Fuse dequantize with convolution #20816

Merged
7 commits merged on Feb 14, 2022
Changes from all commits
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -298,6 +298,7 @@ List of Contributors
* [Bartosz Kuncer](https://github.com/bartekkuncer)
* [Maria Boerner](https://github.com/mariaboerner1987)
* [Zhenghui Jin](https://github.com/barry-jin)
* [Dominika Jedynak](https://github.com/DominikaJedynak)

Label Bot
---------
4 changes: 4 additions & 0 deletions src/operator/nn/dnnl/dnnl_convolution-inl.h
@@ -43,6 +43,7 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
bool with_sum;
bool with_postsum_act;
bool quantized;
bool enable_float_output;
bool dedup_sum;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
@@ -56,6 +57,9 @@ struct DNNLConvParam : public dmlc::Parameter<DNNLConvParam> {
.set_default(false)
.describe("Add post activation after sum");
DMLC_DECLARE_FIELD(quantized).set_default(false).describe("enable quantization");
DMLC_DECLARE_FIELD(enable_float_output)
.set_default(false)
.describe("Whether to enable float32 output");
DMLC_DECLARE_FIELD(dedup_sum).set_default(false).describe("deduplicated sum input");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
99 changes: 55 additions & 44 deletions src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -130,17 +130,7 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
auto& dnnl_param = full_conv_param.dnnl_param;
auto& conv_param = full_conv_param.conv_param;
auto bn_param = param_.bn_param.get();
size_t input_size = 2 + (conv_param.no_bias ? 0 : 1) + (dnnl_param.with_bn ? 4 : 0) +
(dnnl_param.with_sum ? 1 : 0) +
(dnnl_param.quantized ? 2 + (dnnl_param.with_sum ? 2 : 0) : 0);
// When dedup is on, in_data is used to calculate sum instead of in_sum
if (dnnl_param.dedup_sum) {
input_size -= 1;
if (dnnl_param.quantized) {
input_size -= 2;
}
}
CHECK_EQ(inputs.size(), input_size);

index_t idx = 0;

auto in_data = idx++;
@@ -164,11 +154,10 @@
sum_max = inputs[idx++].data().dptr<float>()[0];
}
}
CHECK_EQ(input_size, idx);
CHECK_EQ(inputs.size(), idx);
bool has_bias = dnnl_param.with_bn || !conv_param.no_bias;
NDArray data = inputs[in_data];
NDArray output = dnnl_param.with_sum ? inputs[in_sum] : outputs[kOut];

// Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed.
if (dnnl_param.with_sum) {
if (!initialized_) {
@@ -183,9 +172,9 @@
if (!inplace_) {
auto in_dnnl_mem = inputs[in_sum].GetDNNLData();
auto out_dnnl_mem = outputs[kOut].GetDNNLData();
if (outputs[kOut].dtype() == mshadow::kInt32) {
if (outputs[kOut].dtype() == mshadow::kInt32 || outputs[kOut].dtype() == mshadow::kFloat32) {
const auto& mem_desc = in_dnnl_mem->get_desc();
const auto this_dtype = get_dnnl_type(mshadow::kInt32);
const auto this_dtype = get_dnnl_type(outputs[kOut].dtype());
auto omd = mem_desc;
omd.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
dnnl_mem_ptr tmp_mem(
@@ -265,6 +254,9 @@ void SgDNNLConvOperator::Forward(const OpContext& ctx,
post_requantize_ = true;
weight_channelwise_scale = true;
}
if (dnnl_param.enable_float_output) {
weight_channelwise_scale = true;
}
data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
DNNL_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
weight_scales_ = GetWeightScales<DType>(cached_weight_,
@@ -279,13 +271,24 @@
if (dnnl_param.with_sum) {
sum_in_scale = GetQuantizeScale(inputs[in_sum].dtype(), cached_sum_min_, cached_sum_max_);
}
if (post_requantize_) {
output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? mshadow::kUint8 : mshadow::kInt8,
cached_output_min_,
cached_output_max_);
if (post_requantize_ || dnnl_param.enable_float_output) {
if (post_requantize_) {
output_scale = GetQuantizeScale(IsOutputUInt8(param_) ? mshadow::kUint8 : mshadow::kInt8,
cached_output_min_,
cached_output_max_);
} else {
output_scale = 1.0;
}
full_conv_param.requantize_scales.resize(weight_channelwise_scale ? channel : 1);
for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) {
full_conv_param.requantize_scales[c] = output_scale / data_scale_ / weight_scales_[c];
full_conv_param.requantize_scales[c] = 1.0 / data_scale_ / weight_scales_[c];
}
if (dnnl_param.with_act) {
full_conv_param.act_param.scale = output_scale;
} else {
for (size_t c = 0; c < full_conv_param.requantize_scales.size(); c++) {
full_conv_param.requantize_scales[c] *= output_scale;
}
}
} else {
Stream<cpu>* s = ctx.get_stream<cpu>();
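For reference, a minimal numeric sketch of the scale handling above (plain Python with illustrative values; the real scales come from calibration via GetQuantizeScale and GetWeightScales). With post-requantization the int32 accumulator is rescaled into the calibrated int8/uint8 output range; with enable_float_output the output scale is effectively 1.0, so the per-channel scales simply undo the input and weight quantization and the result comes out in float32. When an activation is fused, output_scale is applied through act_param.scale instead of being folded into the requantize scales.

import numpy as np

# Illustrative values only; real scales come from calibration data.
data_scale = 127.0 / 2.5                                # int8 data calibrated to [-2.5, 2.5]
weight_scales = np.array([127.0 / 0.8, 127.0 / 1.2])    # per-channel weight scales

int32_accumulator = np.array([12345, -6789], dtype=np.int64)

# Requantize path: rescale the int32 accumulator into the calibrated output range.
output_scale = 127.0 / 6.0                              # int8 output calibrated to [-6, 6]
requantize_scales = output_scale / (data_scale * weight_scales)
int8_out = np.clip(np.rint(int32_accumulator * requantize_scales), -127, 127)

# Fused dequantize path (enable_float_output): output_scale is 1.0, so the scales
# just undo the input/weight quantization and yield the float32 convolution result.
dequantize_scales = 1.0 / (data_scale * weight_scales)
float_out = int32_accumulator * dequantize_scales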
@@ -322,14 +325,8 @@
if (dnnl_param.with_sum) {
LOG(ERROR) << "oneDNN doesn't support conv + relu + sum fusion yet.";
full_conv_param.act_param.alpha *= output_scale;
} else {
// For conv+relu6 without sum, we don't need post_ops as output_scale can do the cut off.
dnnl_param.with_act = false;
}
}
if (dnnl_param.with_postsum_act) {
CHECK(full_conv_param.postsum_act_param.alg == dnnl::algorithm::eltwise_relu);
}
}
fwd_.reset(new DNNLConvForward(full_conv_param,
ctx.is_train,
@@ -389,7 +386,7 @@
DNNLConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), new_inputs, req, {output});
}

if (dnnl_param.quantized) {
if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
*outputs[kMin].data().dptr<float>() = cached_output_min_;
*outputs[kMax].data().dptr<float>() = cached_output_max_;
}
@@ -521,10 +518,12 @@ static std::vector<std::string> SgDNNLConvListInputNames(const NodeAttrs& attrs)

static std::vector<std::string> SgDNNLConvListOutputNames(const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
if (param.full_conv_param.dnnl_param.quantized)
if (param.full_conv_param.dnnl_param.quantized &&
!param.full_conv_param.dnnl_param.enable_float_output) {
return std::vector<std::string>{"output", "output_min", "output_max"};
else
} else {
return std::vector<std::string>{"output"};
}
}

static OpStatePtr CreateSgDNNLConvState(const nnvm::NodeAttrs& attrs,
@@ -581,8 +580,10 @@ static bool SgDNNLConvInferShape(const nnvm::NodeAttrs& attrs,
}
}
out_shapes->at(0) = base_out_shapes[0];
SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
if (!param.full_conv_param.dnnl_param.enable_float_output) {
SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
}
return result;
} else {
return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
@@ -628,19 +629,24 @@ static bool SgDNNLConvInferType(const nnvm::NodeAttrs& attrs,
in_types->at(i) = base_in_types[base_idx++];
}
}
if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
if (IsOutputUInt8(param)) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);

if (param.full_conv_param.dnnl_param.enable_float_output) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
} else {
if (param.full_conv_param.dnnl_param.min_calib_range.has_value() &&
param.full_conv_param.dnnl_param.max_calib_range.has_value()) {
if (IsOutputUInt8(param)) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
}
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
}
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
}

TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
}
return result;
} else {
return DefaultSubgraphOpType(attrs, in_types, out_types);
@@ -674,8 +680,10 @@ static bool SgDNNLConvOpStorageType(const nnvm::NodeAttrs& attrs,
}
}
out_stypes->at(0) = base_out_stypes[0];
type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
if (!param.full_conv_param.dnnl_param.enable_float_output) {
type_assign(&out_stypes->at(1), mxnet::kDefaultStorage);
type_assign(&out_stypes->at(2), mxnet::kDefaultStorage);
}
return result;
} else {
return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, in_stypes, out_stypes);
@@ -736,7 +744,10 @@ NNVM_REGISTER_OP(_sg_onednn_conv)
.set_num_inputs(SgDNNLConvNumInputs)
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& param = nnvm::get<DNNLConvFusionParam>(attrs.parsed);
return param.full_conv_param.dnnl_param.quantized ? 3 : 1;
return param.full_conv_param.dnnl_param.quantized &&
!param.full_conv_param.dnnl_param.enable_float_output ?
3 :
1;
})
.set_attr_parser(SgDNNLConvParamParser)
.set_attr<nnvm::FListInputNames>("FListInputNames", SgDNNLConvListInputNames)
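To summarize the registration changes above, a small plain-Python mirror (not an MXNet call, just a restatement of SgDNNLConvListOutputNames, SgDNNLConvInferType, and the num_outputs lambda) of how the fused operator's outputs now depend on the flags:

def sg_onednn_conv_outputs(quantized, enable_float_output,
                           has_calib_range, output_is_uint8):
    # A quantized int8/uint8/int32 output carries its calibration range as two
    # extra float32 outputs; fusing dequantize drops them.
    if quantized and not enable_float_output:
        if has_calib_range:
            out_dtype = "uint8" if output_is_uint8 else "int8"
        else:
            out_dtype = "int32"
        return {"output": out_dtype, "output_min": "float32", "output_max": "float32"}
    if enable_float_output:
        return {"output": "float32"}
    # Not quantized: a single output whose dtype follows default subgraph inference.
    return {"output": "inferred"}

# Example: with fused dequantize the operator has a single float32 output.
assert sg_onednn_conv_outputs(True, True, True, False) == {"output": "float32"}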
28 changes: 14 additions & 14 deletions src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -37,17 +37,15 @@
namespace mxnet {
namespace op {
namespace {
const std::set<std::string> support_req_fusion_op = {"_contrib_quantized_elemwise_add",
"_contrib_quantized_elemwise_mul",
"_contrib_quantized_npi_add",
"_sg_onednn_conv",
"_sg_onednn_fully_connected",
"_sg_onednn_selfatt_qk",
"_sg_onednn_selfatt_valatt",
"_sg_onednn_batch_dot"};

const std::set<const Op*> no_enable_float_output = {Op::Get("_contrib_quantized_elemwise_add"),
Op::Get("_sg_onednn_conv")};
const std::set<std::string> support_req_fusion_op = {
"_contrib_quantized_elemwise_add",
"_contrib_quantized_elemwise_mul",
// "_contrib_quantized_npi_add" - to be added later on
"_sg_onednn_conv",
"_sg_onednn_fully_connected",
"_sg_onednn_selfatt_qk",
"_sg_onednn_selfatt_valatt",
"_sg_onednn_batch_dot"};
} // namespace

class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
@@ -113,8 +111,9 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
matched_list.emplace_back(&new_node);
status = SelectStatus::kRequantize;
if ((raw_node->op() == Op::Get("_sg_onednn_conv")) ||
(raw_node->op() == Op::Get("_contrib_quantized_elemwise_add"))) {
// For now there is no support for dequantize fusion for contrib_quantized_elemwise_add,
// so with this operator we finish after finding the requantize node:
if (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add")) {
status = SelectStatus::kSuccess;
}
return true;
@@ -214,7 +213,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {

// When only fused quantized operator and requantize, set min/max_calib_range,
// When fused quantized operator + requantize + dequantize, set dequantize flag to true.
if ((dequantize_node != nullptr) && (no_enable_float_output.count(fuse_node->op()) == 0)) {
if ((dequantize_node != nullptr) &&
(fuse_node->op() != Op::Get("_contrib_quantized_elemwise_add"))) {
fuse_node->attrs.dict["enable_float_output"] = "True";
} else {
fuse_node->attrs.dict["min_calib_range"] =
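A short plain-Python sketch (names and structure are illustrative, not the actual MXNet pass) of the decision the post-quantize property now makes, based on the C++ changes above: when a dequantize node follows the requantize node and the fused operator supports it, the pass sets enable_float_output on the fused node; otherwise it folds the requantize calibration range into the node's attributes.

def apply_post_quantize_fusion(fuse_node_attrs, requantize_attrs, has_dequantize, op_name):
    # _contrib_quantized_elemwise_add does not support dequantize fusion yet,
    # so for that operator the selector stops after the requantize node.
    if has_dequantize and op_name != "_contrib_quantized_elemwise_add":
        fuse_node_attrs["enable_float_output"] = "True"
    else:
        fuse_node_attrs["min_calib_range"] = str(requantize_attrs["min_calib_range"])
        fuse_node_attrs["max_calib_range"] = str(requantize_attrs["max_calib_range"])
    return fuse_node_attrs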
3 changes: 2 additions & 1 deletion tests/python/dnnl/subgraphs/subgraph_common.py
@@ -90,7 +90,8 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
if k.find('_quantize') != -1:
assert v['out_type'] == out_type
if k.find(quantized_op_name) != -1:
if quantized_op_name.startswith("quantized_sg_onednn_fully_connected") and 'enable_float_output' in v:
if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected")
or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enable_float_output' in v:
continue
assert 'min_calib_range' in v
assert 'max_calib_range' in v
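For illustration, the kind of node attributes this check now tolerates (a hypothetical attribute dictionary, not taken from a real run): a conv or fully-connected node with a fused dequantize carries enable_float_output and no calibration range, while a requantize-only fusion keeps min/max_calib_range.

# Hypothetical attrs of a fused conv node after quantization with dequantize fusion;
# check_qsym_calibrated above skips the calib-range assertions for such nodes.
fused_dequantize_attrs = {
    "quantized": "True",
    "enable_float_output": "True",
}

# A requantize-only fusion instead keeps the calibrated output range.
requantize_only_attrs = {
    "quantized": "True",
    "min_calib_range": "-6.0",
    "max_calib_range": "6.0",
}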
39 changes: 39 additions & 0 deletions tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -115,6 +115,45 @@ def forward(self, x):
check_fusion(net, data_shape, attr, check_quantization=False)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('no_bias', [True, False])
@pytest.mark.parametrize('out_type', ['int8', 'auto'])
Contributor:
Why not uint8?

Contributor Author:
Following the settings of the other tests in this file, I do not test it, as it is a scenario which is not used.

def test_pos_conv_add3(no_bias, data_shape, out_type):
# conv + add fusion case 3
class ConvAdd(nn.HybridBlock):
def __init__(self, use_bias, **kwargs):
super(ConvAdd, self).__init__(**kwargs)
self.conv0 = nn.Conv2D(channels=data_shape[1], kernel_size=(1, 1), strides=1, use_bias=use_bias)

def forward(self, x):
out = x + self.conv0(x)
return out

net = ConvAdd(use_bias=True)
check_quantize(net, data_shape, out_type)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('no_bias', [True, False])
@pytest.mark.parametrize('out_type', ['int8', 'auto'])
Contributor:
Same as above.

def test_pos_conv_add4(no_bias, data_shape, out_type):
# conv + add fusion case 4
class ConvAdd(nn.HybridBlock):
def __init__(self, use_bias, **kwargs):
super(ConvAdd, self).__init__(**kwargs)
self.conv0 = nn.Conv2D(channels=data_shape[1], kernel_size=(1, 1), strides=1, use_bias=use_bias)
self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)

def forward(self, x):
out = self.conv1(x + self.conv0(x))
return out

net = ConvAdd(use_bias=True)
check_quantize(net, data_shape, out_type)


Contributor (@anko-intel, Jan 28, 2022):
What about a test with convolution, activation and sum?

Contributor Author:
They are already there: ConvActAdd, ConvBNSumAct.

@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('alg,quantize', [