@@ -178,4 +178,36 @@ TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
178
178
auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
179
179
auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
180
180
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
181
+ }
182
+
183
+ TEST (Evaluators, FloorIntIntEvaluatesCorrectly) {
184
+ const auto graph = R"IR(
185
+ graph():
186
+ %1 : int = prim::Constant[value=9]()
187
+ %2 : int = aten::floor(%1)
188
+ return (%2))IR" ;
189
+
190
+ auto g = std::make_shared<torch::jit::Graph>();
191
+ torch::jit::parseIR (graph, g.get ());
192
+
193
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
194
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
195
+
196
+ ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
197
+ }
198
+
199
+ TEST (Evaluators, FloorFloatIntEvaluatesCorrectly) {
200
+ const auto graph = R"IR(
201
+ graph():
202
+ %1 : float = prim::Constant[value=9.3]()
203
+ %2 : int = aten::floor(%1)
204
+ return (%2))IR" ;
205
+
206
+ auto g = std::make_shared<torch::jit::Graph>();
207
+ torch::jit::parseIR (graph, g.get ());
208
+
209
+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
210
+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
211
+
212
+ ASSERT_TRUE (jit_results[0 ] == trt_results[0 ]);
181
213
}
0 commit comments