Skip to content

Commit

Permalink
feat: Add aten::type_as lowering pass
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Aug 10, 2021
1 parent 1f2ffc4 commit b57a6dd
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
24 changes: 21 additions & 3 deletions core/lowering/passes/reduce_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,34 @@ void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph) {
graph(%x, %device, %dtype, %nb, %copy, %format):
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
return (%out))IR";
std::string to_general_pattern = R"IR(
std::string to_dtype_pattern = R"IR(
graph(%x, %device, %dtype, %nb, %copy, %format):
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
return (%out))IR";

std::string to_type_as_pattern = R"IR(
graph(%input, %other):
%out : Tensor = aten::type_as(%input, %other)
return (%out))IR";

std::string to_other_pattern = R"IR(
graph(%input, %other):
%5 : bool = prim::Constant[value=0]()
%6 : None = prim::Constant()
%out : Tensor = aten::to(%input, %other, %5, %5, %6)
return (%out))IR";

// replace aten::to.device with aten::to.dtype
torch::jit::SubgraphRewriter map_aten_device_to_dtype;
map_aten_device_to_dtype.RegisterRewritePattern(to_device_pattern, to_general_pattern);
map_aten_device_to_dtype.RegisterRewritePattern(to_device_pattern, to_dtype_pattern);
map_aten_device_to_dtype.runOnGraph(graph);
LOG_GRAPH("Post lowering of aten::to.device -> " << *graph);

// replace aten::type_as with aten::to.other
torch::jit::SubgraphRewriter map_aten_type_as_to_other;
map_aten_type_as_to_other.RegisterRewritePattern(to_type_as_pattern, to_other_pattern);
map_aten_type_as_to_other.runOnGraph(graph);

LOG_GRAPH("Post lowering of [aten::to.device|aten::type_as] -> " << *graph);
}

} // namespace passes
Expand Down
29 changes: 29 additions & 0 deletions tests/core/conversion/converters/test_cast.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
Expand Down Expand Up @@ -133,3 +134,31 @@ TEST(Converters, ATenBoolToINT32TensorConvertsCorrectly) {

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenTypeAsConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor):
%2 : int = prim::Constant[value=-1]()
%a : int = prim::Constant[value=1]()
%4 : Tensor = aten::add(%0, %2, %a)
%5 : Tensor = aten::gt(%1, %a)
%6 : Tensor = aten::type_as(%4, %5)
return (%6, %5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in1 = at::randint(1, 3, {3, 4, 3}, {at::kCUDA});
auto in2 = at::randint(1, 3, {3, 4, 3}, {at::kCUDA});
// Lower aten::type_as to aten::to.other
trtorch::core::lowering::passes::ReduceToOperation(g);

auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});

params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
23 changes: 23 additions & 0 deletions tests/core/lowering/test_reduce_to_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,26 @@ TEST(LoweringPasses, ReduceToCorrectly) {

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, ReduceAtenTypeAsCorrectly) {
std::string source_graph = R"IR(
graph(%input, %other):
%out : Tensor = aten::type_as(%input, %other)
return (%out))IR";
std::string target_graph = R"IR(
graph(%input, %other):
%5 : bool = prim::Constant[value=0]()
%6 : None = prim::Constant()
%out : Tensor = aten::to(%input, %other, %5, %5, %6)
return (%out))IR";

trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
trtorch::core::lowering::passes::ReduceToOperation(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

0 comments on commit b57a6dd

Please sign in to comment.