diff --git a/core/lowering/passes/remove_unnecessary_casts.cpp b/core/lowering/passes/remove_unnecessary_casts.cpp
index 3386608f0d..451e77238e 100644
--- a/core/lowering/passes/remove_unnecessary_casts.cpp
+++ b/core/lowering/passes/remove_unnecessary_casts.cpp
@@ -138,6 +138,48 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
           user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
           user->destroy();
           break;
+        case c10::aten::div:
+          // If the first two entries to aten::div are non-Tensors,
+          // there cannot be a rounding mode specified (3rd entry)
+          if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) &&
+              !user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) &&
+              user->inputs().size() == 3 &&
+              user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) &&
+              torch::jit::toIValue(user->inputs()[2]).has_value()) {
+            // Select the first 2 entries of the inputs, corresponding to the values
+            auto div_args = user->inputs().slice(0, 2);
+
+            // Depending on the rounding mode, create the appropriate nodes
+            if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") {
+              // Truncate case (round result towards 0)
+              torch::jit::Node* new_node_div;
+              // Create node which simply divides the two entries
+              new_node_div = g->create(c10::aten::div, div_args, 1);
+              new_node_div->insertAfter(user);
+              new_node_div->outputs()[0]->setType(c10::FloatType::get());
+
+              // Create node which casts the result to an integer, effectively truncating
+              new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1);
+              new_node->insertAfter(new_node_div);
+              new_node->outputs()[0]->setType(c10::IntType::get());
+
+              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+              user->destroy();
+              break;
+
+            } else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") {
+              // Floor case (round result down)
+              // Replace aten::div with aten::floordiv
+              new_node = g->create(c10::aten::floordiv, div_args, 1);
+              new_node->insertAfter(user);
+              new_node->outputs()[0]->setType(c10::IntType::get());
+
+              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+              user->destroy();
+              break;
+            }
+          }
+
         default:
           new_node = g->create(user->kind(), user->inputs(), 1);
           new_node->insertAfter(user);
diff --git a/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tests/core/lowering/test_remove_unnecessary_casts.cpp
index dc4c397148..704b2064ea 100644
--- a/tests/core/lowering/test_remove_unnecessary_casts.cpp
+++ b/tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -5,6 +5,8 @@
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
 #include "torch/csrc/jit/ir/subgraph_matcher.h"
+#include "torch/csrc/jit/passes/canonicalize.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
 
 TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
   std::string source_graph = R"IR(
@@ -255,6 +257,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
   ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
 }
 
+TEST(LoweringPasses, RemoveSingleUse0DTensorsDivTruncIntValuesAgree) {
+  // Ensure the source and target graphs have equivalent outputs
+  // (Source and Target are computing equivalent values)
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %11: int = prim::Constant[value=-3]()
+      %234 : str = prim::Constant[value="trunc"]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::div(%1, %3, %234)
+      %50: int = aten::Int(%4)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %1: int = prim::Constant[value=-3]()
+      %40: float = aten::div(%1, %0)
+      %41: int = aten::Int(%40)
+      %4: Tensor = prim::NumToTensor(%41)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
+
+  // Ensure the lowering pass transforms the first graph into the second
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %234: str = prim::Constant[value="trunc"]()
+      %4: Tensor = aten::div(%3, %1, %234)
+      %5: int = aten::Int(%4)
+      return (%5))IR";
+
+  std::string target_graph = R"IR(
+    graph(%0 : int):
+      %1 : str = prim::Constant[value="trunc"]()
+      %2 : int = prim::Constant[value=8]()
+      %3 : float = aten::div(%0, %2)
+      %4 : int = aten::Int(%3)
+      return (%4))IR";
+
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsDivFloorIntValuesAgree) {
+  // Ensure the source and target graphs have equivalent outputs
+  // (Source and Target are computing equivalent values)
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %11: int = prim::Constant[value=-3]()
+      %234 : str = prim::Constant[value="floor"]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::div(%1, %3, %234)
+      %50: int = aten::Int(%4)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %1: int = prim::Constant[value=-3]()
+      %40: int = aten::floordiv(%1, %0)
+      %41: int = aten::Int(%40)
+      %4: Tensor = prim::NumToTensor(%41)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
+
+  // Ensure the lowering pass transforms the first graph into the second
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %234: str = prim::Constant[value="floor"]()
+      %4: Tensor = aten::div(%3, %1, %234)
+      %5: int = aten::Int(%4)
+      return (%5))IR";
+
+  std::string target_graph = R"IR(
+    graph(%0 : int):
+      %1 : str = prim::Constant[value="floor"]()
+      %2 : int = prim::Constant[value=8]()
+      %3 : int = aten::floordiv(%0, %2)
+      return (%3))IR";
+
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+}
+
 TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
   std::string source_graph_no_inputs = R"IR(
     graph():