Skip to content

Commit

Permalink
feat(aten::floor): Adds floor.int evaluator
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 21, 2021
1 parent e5a6468 commit a6a46e5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
16 changes: 13 additions & 3 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,21 @@ auto aten_registrations TRTORCH_UNUSED =
})})
.evaluator({c10::Symbol::fromQualString("aten::floor"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto el = args.at(n->input(0)).unwrapToDouble();

return static_cast<int64_t>(std::floor(el));
if (args.at(n->input(0)).IValue()->isInt()) {
auto el = args.at(n->input(0)).unwrapToInt();
return static_cast<int64_t>(std::floor(el));
} else if (args.at(n->input(0)).IValue()->isDouble()) {
auto el = args.at(n->input(0)).unwrapToDouble();
return static_cast<int64_t>(std::floor(el));
} else {
TRTORCH_THROW_ERROR(
"Unimplemented data type for aten::floor evaluator: "
<< args.at(n->input(0)).IValue()->type()->str());
return {};
}
},
EvalOptions().validSchemas({
"aten::floor.int(int a) -> (int)",
"aten::floor.float(float a) -> (int)",
})})
.evaluator({c10::Symbol::fromQualString("aten::warn"),
Expand Down
32 changes: 32 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,36 @@ TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

TEST(Evaluators, FloorIntIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : int = prim::Constant[value=9]()
%2 : int = aten::floor(%1)
return (%2))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%1 : float = prim::Constant[value=9.3]()
%2 : int = aten::floor(%1)
return (%2))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit a6a46e5

Please sign in to comment.