Skip to content

Commit a6a46e5

Browse files
committed
feat(aten::floor): Adds floor.int evaluator
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent e5a6468 commit a6a46e5

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

core/conversion/evaluators/aten.cpp

+13-3
Original file line numberDiff line numberDiff line change
@@ -468,11 +468,21 @@ auto aten_registrations TRTORCH_UNUSED =
468468
})})
469469
.evaluator({c10::Symbol::fromQualString("aten::floor"),
470470
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
471-
auto el = args.at(n->input(0)).unwrapToDouble();
472-
473-
return static_cast<int64_t>(std::floor(el));
471+
if (args.at(n->input(0)).IValue()->isInt()) {
472+
auto el = args.at(n->input(0)).unwrapToInt();
473+
return static_cast<int64_t>(std::floor(el));
474+
} else if (args.at(n->input(0)).IValue()->isDouble()) {
475+
auto el = args.at(n->input(0)).unwrapToDouble();
476+
return static_cast<int64_t>(std::floor(el));
477+
} else {
478+
TRTORCH_THROW_ERROR(
479+
"Unimplemented data type for aten::floor evaluator: "
480+
<< args.at(n->input(0)).IValue()->type()->str());
481+
return {};
482+
}
474483
},
475484
EvalOptions().validSchemas({
485+
"aten::floor.int(int a) -> (int)",
476486
"aten::floor.float(float a) -> (int)",
477487
})})
478488
.evaluator({c10::Symbol::fromQualString("aten::warn"),

tests/core/conversion/evaluators/test_aten_evaluators.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,36 @@ TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
178178
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
179179
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
180180
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
181+
}
182+
183+
TEST(Evaluators, FloorIntIntEvaluatesCorrectly) {
184+
const auto graph = R"IR(
185+
graph():
186+
%1 : int = prim::Constant[value=9]()
187+
%2 : int = aten::floor(%1)
188+
return (%2))IR";
189+
190+
auto g = std::make_shared<torch::jit::Graph>();
191+
torch::jit::parseIR(graph, g.get());
192+
193+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
194+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
195+
196+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
197+
}
198+
199+
TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
200+
const auto graph = R"IR(
201+
graph():
202+
%1 : float = prim::Constant[value=9.3]()
203+
%2 : int = aten::floor(%1)
204+
return (%2))IR";
205+
206+
auto g = std::make_shared<torch::jit::Graph>();
207+
torch::jit::parseIR(graph, g.get());
208+
209+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
210+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
211+
212+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
181213
}

0 commit comments

Comments
 (0)