Skip to content

Commit

Permalink
fix: Repair invalid schema arising from lowering pass (#1786)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Mar 30, 2023
1 parent 6bd7e14 commit 149b2b2
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 0 deletions.
42 changes: 42 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,48 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
// Handle scalar aten::div carrying a string rounding-mode third argument.
// The generic `default:` path below would recreate the node with all three
// inputs, producing a schema-invalid scalar aten::div(Scalar, Scalar, str);
// instead, lower each rounding mode to schema-valid scalar ops.
case c10::aten::div:
// If the first two entries to aten::div are non-Tensors,
// there cannot be a rounding mode specified (3rd entry)
if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) &&
!user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) &&
user->inputs().size() == 3 &&
user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) &&
torch::jit::toIValue(user->inputs()[2]).has_value()) {
// Select the first 2 entries of the inputs, corresponding to the values
auto div_args = user->inputs().slice(0, 2);

// Depending on the rounding mode, create the appropriate nodes
if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") {
// Truncate case (round result towards 0):
// emit a plain two-argument divide followed by an integer cast.
torch::jit::Node* new_node_div;
// Create node which simply divides the two entries
new_node_div = g->create(c10::aten::div, div_args, 1);
new_node_div->insertAfter(user);
new_node_div->outputs()[0]->setType(c10::FloatType::get());

// Create node which casts the result to an integer, effectively truncating
new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1);
new_node->insertAfter(new_node_div);
new_node->outputs()[0]->setType(c10::IntType::get());

user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;

} else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") {
// Floor case (round result down)
// Replace aten::div with aten::floordiv
new_node = g->create(c10::aten::floordiv, div_args, 1);
new_node->insertAfter(user);
new_node->outputs()[0]->setType(c10::IntType::get());

user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
user->destroy();
break;
}
// NOTE: any rounding mode other than "trunc"/"floor" (and any aten::div
// that fails the guard above) intentionally falls through to `default:`,
// which recreates the node with its original inputs unchanged.
}

default:
new_node = g->create(user->kind(), user->inputs(), 1);
new_node->insertAfter(user);
Expand Down
151 changes: 151 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/constant_pooling.h"

TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
std::string source_graph = R"IR(
Expand Down Expand Up @@ -255,6 +257,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsDivTruncIntValuesAgree) {
  // Phase 1: sanity-check the lowering itself — evaluate both the original
  // graph (aten::div with rounding_mode="trunc") and its intended lowered
  // form, and require identical numerical results.
  std::string trunc_src_ir = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%11: int = prim::Constant[value=-3]()
%234 : str = prim::Constant[value="trunc"]()
%3: Tensor = prim::NumToTensor(%0)
%1: Tensor = prim::NumToTensor(%11)
%4: Tensor = aten::div(%1, %3, %234)
%50: int = aten::Int(%4)
%5: Tensor = prim::NumToTensor(%50)
return (%5))IR";
  std::string trunc_tgt_ir = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%1: int = prim::Constant[value=-3]()
%40: float = aten::div(%1, %0)
%41: int = aten::Int(%40)
%4: Tensor = prim::NumToTensor(%41)
return (%4))IR";

  auto src_g = std::make_shared<torch::jit::Graph>();
  auto tgt_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(trunc_src_ir, src_g.get());
  torch::jit::parseIR(trunc_tgt_ir, tgt_g.get());

  auto src_results = torch_tensorrt::tests::util::EvaluateGraphJIT(src_g, {});
  auto tgt_results = torch_tensorrt::tests::util::EvaluateGraphJIT(tgt_g, {});

  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(src_results[0].toTensor(), tgt_results[0].toTensor()));

  // Phase 2: run the actual lowering pass and require that it rewrites the
  // source graph into exactly the target form.
  std::string source_graph = R"IR(
graph(%0: int):
%1: Tensor = prim::Constant[value=[8]]()
%3: Tensor = prim::NumToTensor(%0)
%234: str = prim::Constant[value="trunc"]()
%4: Tensor = aten::div(%3, %1, %234)
%5: int = aten::Int(%4)
return (%5))IR";

  std::string target_graph = R"IR(
graph(%0 : int):
%1 : str = prim::Constant[value="trunc"]()
%2 : int = prim::Constant[value=8]()
%3 : float = aten::div(%0, %2)
%4 : int = aten::Int(%3)
return (%4))IR";

  auto lowered = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, lowered.get());

  // The IR parser cannot express a 0-D tensor literal, so swap the parsed
  // list-valued constant for a genuine scalar-tensor constant up front.
  auto head_node = *(lowered->block()->nodes().begin());
  torch::jit::Value* scalar_const = lowered->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, head_node->scope());
  scalar_const->copyMetadata(head_node->output());
  scalar_const->setType(c10::TensorType::get());
  head_node->output()->replaceAllUsesWith(scalar_const);
  head_node->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(lowered);
  torch::jit::ConstantPooling(lowered);
  lowered = torch::jit::Canonicalize(lowered, false);

  auto expected = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, expected.get());
  torch::jit::ConstantPooling(expected);
  expected = torch::jit::Canonicalize(expected, false);

  // After constant pooling and canonicalization the textual IR must match
  ASSERT_TRUE((expected->toString() == lowered->toString()));
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsDivFloorIntValuesAgree) {
  // Phase 1: sanity-check the lowering itself — evaluate both the original
  // graph (aten::div with rounding_mode="floor") and its intended lowered
  // form (aten::floordiv), and require identical numerical results.
  std::string floor_src_ir = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%11: int = prim::Constant[value=-3]()
%234 : str = prim::Constant[value="floor"]()
%3: Tensor = prim::NumToTensor(%0)
%1: Tensor = prim::NumToTensor(%11)
%4: Tensor = aten::div(%1, %3, %234)
%50: int = aten::Int(%4)
%5: Tensor = prim::NumToTensor(%50)
return (%5))IR";
  std::string floor_tgt_ir = R"IR(
graph():
%0: int = prim::Constant[value=2]()
%1: int = prim::Constant[value=-3]()
%40: int = aten::floordiv(%1, %0)
%41: int = aten::Int(%40)
%4: Tensor = prim::NumToTensor(%41)
return (%4))IR";

  auto src_g = std::make_shared<torch::jit::Graph>();
  auto tgt_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(floor_src_ir, src_g.get());
  torch::jit::parseIR(floor_tgt_ir, tgt_g.get());

  auto src_results = torch_tensorrt::tests::util::EvaluateGraphJIT(src_g, {});
  auto tgt_results = torch_tensorrt::tests::util::EvaluateGraphJIT(tgt_g, {});

  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(src_results[0].toTensor(), tgt_results[0].toTensor()));

  // Phase 2: run the actual lowering pass and require that it rewrites the
  // source graph into exactly the target form.
  std::string source_graph = R"IR(
graph(%0: int):
%1: Tensor = prim::Constant[value=[8]]()
%3: Tensor = prim::NumToTensor(%0)
%234: str = prim::Constant[value="floor"]()
%4: Tensor = aten::div(%3, %1, %234)
%5: int = aten::Int(%4)
return (%5))IR";

  std::string target_graph = R"IR(
graph(%0 : int):
%1 : str = prim::Constant[value="floor"]()
%2 : int = prim::Constant[value=8]()
%3 : int = aten::floordiv(%0, %2)
return (%3))IR";

  auto lowered = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, lowered.get());

  // The IR parser cannot express a 0-D tensor literal, so swap the parsed
  // list-valued constant for a genuine scalar-tensor constant up front.
  auto head_node = *(lowered->block()->nodes().begin());
  torch::jit::Value* scalar_const = lowered->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, head_node->scope());
  scalar_const->copyMetadata(head_node->output());
  scalar_const->setType(c10::TensorType::get());
  head_node->output()->replaceAllUsesWith(scalar_const);
  head_node->destroy();

  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(lowered);
  torch::jit::ConstantPooling(lowered);
  lowered = torch::jit::Canonicalize(lowered, false);

  auto expected = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, expected.get());
  torch::jit::ConstantPooling(expected);
  expected = torch::jit::Canonicalize(expected, false);

  // After constant pooling and canonicalization the textual IR must match
  ASSERT_TRUE((expected->toString() == lowered->toString()));
}

TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
std::string source_graph_no_inputs = R"IR(
graph():
Expand Down

0 comments on commit 149b2b2

Please sign in to comment.