diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 0e5c810eb8..94728d77ab 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -468,11 +468,21 @@ auto aten_registrations TRTORCH_UNUSED = })}) .evaluator({c10::Symbol::fromQualString("aten::floor"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto el = args.at(n->input(0)).unwrapToDouble(); - - return static_cast(std::floor(el)); + if (args.at(n->input(0)).IValue()->isInt()) { + auto el = args.at(n->input(0)).unwrapToInt(); + return static_cast(std::floor(el)); + } else if (args.at(n->input(0)).IValue()->isDouble()) { + auto el = args.at(n->input(0)).unwrapToDouble(); + return static_cast(std::floor(el)); + } else { + TRTORCH_THROW_ERROR( + "Unimplemented data type for aten::floor evaluator: " + << args.at(n->input(0)).IValue()->type()->str()); + return {}; + } }, EvalOptions().validSchemas({ + "aten::floor.int(int a) -> (int)", "aten::floor.float(float a) -> (int)", })}) .evaluator({c10::Symbol::fromQualString("aten::warn"), diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp index a7bdf36c4b..e3158c9b98 100644 --- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp @@ -178,4 +178,36 @@ TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) { auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6)); +} + +TEST(Evaluators, FloorIntIntEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=9]() + %2 : int = aten::floor(%1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : float = prim::Constant[value=9.3]() + %2 : int = aten::floor(%1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); } \ No newline at end of file