Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Resolving conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
DominikaJedynak committed Feb 8, 2022
1 parent d165bb8 commit 07c80f4
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
matched_list.emplace_back(&new_node);
status = SelectStatus::kRequantize;
if (no_enable_float_output.count(raw_node->op()) == 0) {
// For now there is no support for dequantize fusion for contrib_quantized_elemwise_add
// so with this operator we finish after finding requantize node:
if (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add")) {
status = SelectStatus::kSuccess;
}
return true;
Expand Down Expand Up @@ -187,11 +189,9 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {

nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
const int subgraph_id = 0) const override {
nnvm::ObjectPtr fuse_node = nullptr;
nnvm::ObjectPtr requantize_node = nullptr;
nnvm::ObjectPtr dequantize_node = nullptr;
static const std::set<const Op*> no_enable_float_output = {
Op::Get("_contrib_quantized_elemwise_add")};
nnvm::ObjectPtr fuse_node = nullptr;
nnvm::ObjectPtr requantize_node = nullptr;
nnvm::ObjectPtr dequantize_node = nullptr;

DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
if (node->is_variable())
Expand All @@ -213,7 +213,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {

// When only fused quantized operator and requantize, set min/max_cablib_range,
// When fused quantized operator + requantize + dequantize, set dequantize flag to true.
if ((dequantize_node != nullptr) && (no_enable_float_output.count(fuse_node->op()) == 0)) {
if ((dequantize_node != nullptr) &&
(fuse_node->op() != Op::Get("_contrib_quantized_elemwise_add"))) {
fuse_node->attrs.dict["enable_float_output"] = "True";
} else {
fuse_node->attrs.dict["min_calib_range"] =
Expand Down

0 comments on commit 07c80f4

Please sign in to comment.