From 07c80f4d29e7587181cea37a6fbd78407fc56917 Mon Sep 17 00:00:00 2001 From: Dominika Jedynak Date: Tue, 8 Feb 2022 20:22:15 +0100 Subject: [PATCH] Resolving conflicts --- .../subgraph/dnnl/dnnl_post_quantize_property.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h index 5bf56e5d779d..456a0d10399e 100644 --- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h +++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h @@ -111,7 +111,9 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 { if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { matched_list.emplace_back(&new_node); status = SelectStatus::kRequantize; - if (no_enable_float_output.count(raw_node->op()) == 0) { + // For now there is no support for dequantize fusion for contrib_quantized_elemwise_add + // so with this operator we finish after finding requantize node: + if (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add")) { status = SelectStatus::kSuccess; } return true; @@ -187,11 +189,9 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty { nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, const int subgraph_id = 0) const override { - nnvm::ObjectPtr fuse_node = nullptr; - nnvm::ObjectPtr requantize_node = nullptr; - nnvm::ObjectPtr dequantize_node = nullptr; - static const std::set no_enable_float_output = { - Op::Get("_contrib_quantized_elemwise_add")}; + nnvm::ObjectPtr fuse_node = nullptr; + nnvm::ObjectPtr requantize_node = nullptr; + nnvm::ObjectPtr dequantize_node = nullptr; DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) { if (node->is_variable()) @@ -213,7 +213,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty { // When only fused quantized operator and requantize, set min/max_cablib_range, // When fused quantized operator + requantize + dequantize, set dequantize flag to true. - if ((dequantize_node != nullptr) && (no_enable_float_output.count(fuse_node->op()) == 0)) { + if ((dequantize_node != nullptr) && + (fuse_node->op() != Op::Get("_contrib_quantized_elemwise_add"))) { fuse_node->attrs.dict["enable_float_output"] = "True"; } else { fuse_node->attrs.dict["min_calib_range"] =