@@ -203,6 +203,173 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
203
203
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
204
204
}
205
205
206
+ TEST (Converters, ATenConvTransposeConvertsCorrectly) {
207
+ const auto graph = R"IR(
208
+ graph(%0 : Tensor,
209
+ %1 : Float(8, 3, 3, 3),
210
+ %2 : Float(8)):
211
+ %3 : int = prim::Constant[value=1]()
212
+ %4 : int = prim::Constant[value=0]()
213
+ %5 : int = prim::Constant[value=1]()
214
+ %6 : int = prim::Constant[value=0]()
215
+ %7 : bool = prim::Constant[value=1]()
216
+ %8 : int[] = prim::ListConstruct(%3, %3)
217
+ %9 : int[] = prim::ListConstruct(%4, %4)
218
+ %10 : int[] = prim::ListConstruct(%5, %5)
219
+ %11 : int[] = prim::ListConstruct(%6, %6)
220
+ %12 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %3, %7, %7, %7)
221
+ return (%12))IR" ;
222
+
223
+ auto g = std::make_shared<torch::jit::Graph>();
224
+ torch::jit::parseIR (graph, &*g);
225
+
226
+ auto in = at::randint (1 , 3 , {1 , 8 , 5 , 5 }, {at::kCUDA });
227
+ auto w = at::randint (1 , 3 , {8 , 3 , 3 , 3 }, {at::kCUDA });
228
+ auto b = at::randint (1 , 3 , {3 }, {at::kCUDA });
229
+
230
+ auto jit_in = at::clone (in);
231
+ auto jit_w = at::clone (w);
232
+ auto jit_b = at::clone (b);
233
+
234
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
235
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
236
+
237
+ auto trt_in = at::clone (in);
238
+ auto trt_w = at::clone (w);
239
+ auto trt_b = at::clone (b);
240
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
241
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
242
+
243
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
244
+
245
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
246
+ }
247
+
248
+ TEST (Converters, ATenConvTransposeNoBiasConvertsCorrectly) {
249
+ const auto graph = R"IR(
250
+ graph(%0 : Tensor,
251
+ %1 : Float(4, 1, 3, 3)):
252
+ %2 : None = prim::Constant()
253
+ %3 : int = prim::Constant[value=1]()
254
+ %4 : int = prim::Constant[value=0]()
255
+ %5 : int = prim::Constant[value=1]()
256
+ %6 : int = prim::Constant[value=0]()
257
+ %7 : bool = prim::Constant[value=1]()
258
+ %8 : int[] = prim::ListConstruct(%3, %3)
259
+ %9 : int[] = prim::ListConstruct(%4, %4)
260
+ %10 : int[] = prim::ListConstruct(%5, %5)
261
+ %11 : int[] = prim::ListConstruct(%6, %6)
262
+ %12 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %3, %7, %7, %7)
263
+ return (%12))IR" ;
264
+
265
+ auto g = std::make_shared<torch::jit::Graph>();
266
+ torch::jit::parseIR (graph, &*g);
267
+
268
+ auto in = at::randint (1 , 2 , {1 , 4 , 3 , 3 }, {at::kCUDA });
269
+ auto w = at::randint (1 , 2 , {4 , 1 , 2 , 2 }, {at::kCUDA });
270
+
271
+ auto jit_in = at::clone (in);
272
+ auto jit_w = at::clone (w);
273
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w});
274
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
275
+
276
+ auto trt_in = at::clone (in);
277
+ auto trt_w = at::clone (w);
278
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w});
279
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
280
+
281
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
282
+
283
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
284
+ }
285
+
286
+
287
+ TEST (Converters, ATenConvTransposeWithStrideConvertsCorrectly) {
288
+ const auto graph = R"IR(
289
+ graph(%0 : Tensor,
290
+ %1 : Float(4, 3, 3, 3),
291
+ %2 : Float(4)):
292
+ %3 : int = prim::Constant[value=3]()
293
+ %4 : int = prim::Constant[value=0]()
294
+ %5 : int = prim::Constant[value=1]()
295
+ %6 : int = prim::Constant[value=0]()
296
+ %7 : bool = prim::Constant[value=1]()
297
+ %8 : int[] = prim::ListConstruct(%3, %3)
298
+ %9 : int[] = prim::ListConstruct(%4, %4)
299
+ %10 : int[] = prim::ListConstruct(%5, %5)
300
+ %11 : int[] = prim::ListConstruct(%6, %6)
301
+ %12 : int = prim::Constant[value=1]()
302
+ %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
303
+ return (%13))IR" ;
304
+
305
+ auto g = std::make_shared<torch::jit::Graph>();
306
+ torch::jit::parseIR (graph, &*g);
307
+
308
+ auto in = at::randint (1 , 10 , {1 , 4 , 9 , 9 }, {at::kCUDA });
309
+ auto w = at::randint (1 , 10 , {4 , 3 , 3 , 3 }, {at::kCUDA });
310
+ auto b = at::randint (1 , 10 , {3 }, {at::kCUDA });
311
+
312
+ auto jit_in = at::clone (in);
313
+ auto jit_w = at::clone (w);
314
+ auto jit_b = at::clone (b);
315
+
316
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
317
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
318
+
319
+ auto trt_in = at::clone (in);
320
+ auto trt_w = at::clone (w);
321
+ auto trt_b = at::clone (b);
322
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
323
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
324
+
325
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
326
+
327
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
328
+ }
329
+
330
+ TEST (Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
331
+ const auto graph = R"IR(
332
+ graph(%0 : Tensor,
333
+ %1 : Float(4, 3, 4, 4),
334
+ %2 : Float(4)):
335
+ %3 : int = prim::Constant[value=1]()
336
+ %4 : int = prim::Constant[value=2]()
337
+ %5 : int = prim::Constant[value=1]()
338
+ %6 : int = prim::Constant[value=0]()
339
+ %7 : bool = prim::Constant[value=1]()
340
+ %8 : int[] = prim::ListConstruct(%3, %3)
341
+ %9 : int[] = prim::ListConstruct(%4, %4)
342
+ %10 : int[] = prim::ListConstruct(%5, %5)
343
+ %11 : int[] = prim::ListConstruct(%6, %6)
344
+ %12 : int = prim::Constant[value=1]()
345
+ %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
346
+ return (%13))IR" ;
347
+
348
+ auto g = std::make_shared<torch::jit::Graph>();
349
+ torch::jit::parseIR (graph, &*g);
350
+
351
+ auto in = at::randint (1 , 10 , {1 , 4 , 4 , 4 }, {at::kCUDA });
352
+ auto w = at::randint (1 , 10 , {4 , 3 , 2 , 2 }, {at::kCUDA });
353
+ auto b = at::randint (1 , 10 , {3 }, {at::kCUDA });
354
+
355
+ auto jit_in = at::clone (in);
356
+ auto jit_w = at::clone (w);
357
+ auto jit_b = at::clone (b);
358
+
359
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
360
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
361
+
362
+ auto trt_in = at::clone (in);
363
+ auto trt_w = at::clone (w);
364
+ auto trt_b = at::clone (b);
365
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
366
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
367
+
368
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
369
+
370
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
371
+ }
372
+
206
373
// TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
207
374
// const auto graph = R"IR(
208
375
// graph(%0 : Tensor,
0 commit comments