Commit ee727f8

fix(aten::_convolution): out channels was passed in incorrectly for deconv

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>

1 parent: ce6cf75

4 files changed: +178 -11 lines changed

core/conversion/converters/impl/conv_deconv.cpp (+6 -8)

@@ -18,27 +18,25 @@ auto conv_registrations = RegisterNodeConversionPatterns()
     auto in = args[0].ITensor();
 
     auto w = Weights(ctx, args[1].unwrapToTensor());
-    auto stride = util::toDimsHW(args[3].unwrapToIntList());
+    auto stride = util::toDims(args[3].unwrapToIntList());
     LOG_DEBUG("stride: " << stride);
-    auto padding = util::toDimsHW(args[4].unwrapToIntList());
+    auto padding = util::toDims(args[4].unwrapToIntList());
     LOG_DEBUG("padding: " << padding);
-    auto dilation = util::toDimsHW(args[5].unwrapToIntList());
+    auto dilation = util::toDims(args[5].unwrapToIntList());
     LOG_DEBUG("dilation: " << dilation);
     bool transposed = args[6].unwrapToBool();
-    auto out_padding = util::toDimsHW(args[7].unwrapToIntList());
+    auto out_padding = util::toDims(args[7].unwrapToIntList());
     LOG_DEBUG("out_padding: " << out_padding);
     int64_t groups = args[8].unwrapToInt();
 
     nvinfer1::ILayer* new_layer;
     if (transposed) {
-      //TODO: Check deconv correctness
-      LOG_WARNING(ctx->logger, "Deconvolution converter has not be tested");
       nvinfer1::IDeconvolutionLayer* deconv;
       if (args[2].IValue()->isTensor()) {
        Weights b(ctx, args[2].IValue()->toTensor());
-        deconv = ctx->net->addDeconvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, b.data);
+        deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, b.data);
       } else {
-        deconv = ctx->net->addDeconvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, {});
+        deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, {});
      }
 
      TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
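
The substantive change is swapping w.num_output_maps for w.num_input_maps when building the deconvolution layer. PyTorch lays out regular convolution weights as [out_channels, in_channels/groups, kH, kW], but transposed convolution weights as [in_channels, out_channels/groups, kH, kW], so for a deconv the channel count TensorRT expects as nbOutputMaps sits in the weight's second dimension, the field the Weights helper records as num_input_maps. (The toDimsHW -> toDims switches, judging by the names, likewise stop coercing the stride/padding/dilation lists into a fixed 2D DimsHW and convert them at their given rank.) A minimal sketch of the shape logic, assuming groups == 1; the helper below is hypothetical, only the weight layouts come from PyTorch:

#include <cstdint>
#include <vector>

// Hypothetical helper, not part of the patch: pick the channel count that
// TensorRT's addConvolutionNd / addDeconvolutionNd expects as nbOutputMaps.
// Assumes a 4D weight tensor and groups == 1.
int64_t nb_output_maps(const std::vector<int64_t>& weight_shape, bool transposed) {
  // conv weight:            [out_channels, in_channels, kH, kW] -> dim 0
  // transposed conv weight: [in_channels, out_channels, kH, kW] -> dim 1
  return transposed ? weight_shape[1] : weight_shape[0];
}

Passing dim 0 for a deconv, as the old code effectively did, hands TensorRT the input-channel count instead, which is why transposed convolutions with in_channels != out_channels produced a wrongly shaped layer.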

tests/core/converters/BUILD (+2 -2)

@@ -16,7 +16,7 @@ converter_test(
 )
 
 converter_test(
-    name = "test_conv"
+    name = "test_conv_deconv"
 )
 
 converter_test(
@@ -56,7 +56,7 @@ test_suite(
     tests = [
         ":test_activation",
         ":test_batch_norm",
-        ":test_conv",
+        ":test_conv_deconv",
         ":test_element_wise",
         ":test_linear",
         ":test_matrix_multiply",

tests/core/converters/test_conv.cpp renamed to tests/core/converters/test_conv_deconv.cpp (+167)

@@ -203,6 +203,173 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenConvTransposeConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(8, 3, 3, 3),
+          %2 : Float(8)):
+      %3 : int = prim::Constant[value=1]()
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=0]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : int[] = prim::ListConstruct(%4, %4)
+      %10 : int[] = prim::ListConstruct(%5, %5)
+      %11 : int[] = prim::ListConstruct(%6, %6)
+      %12 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %3, %7, %7, %7)
+      return (%12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 3, {1, 8, 5, 5}, {at::kCUDA});
+  auto w = at::randint(1, 3, {8, 3, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 3, {3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenConvTransposeNoBiasConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 1, 3, 3)):
+      %2 : None = prim::Constant()
+      %3 : int = prim::Constant[value=1]()
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=0]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : int[] = prim::ListConstruct(%4, %4)
+      %10 : int[] = prim::ListConstruct(%5, %5)
+      %11 : int[] = prim::ListConstruct(%6, %6)
+      %12 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %3, %7, %7, %7)
+      return (%12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 2, {1, 4, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 1, 2, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+
+TEST(Converters, ATenConvTransposeWithStrideConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, 3),
+          %2 : Float(4)):
+      %3 : int = prim::Constant[value=3]()
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=0]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : int[] = prim::ListConstruct(%4, %4)
+      %10 : int[] = prim::ListConstruct(%5, %5)
+      %11 : int[] = prim::ListConstruct(%6, %6)
+      %12 : int = prim::Constant[value=1]()
+      %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
+      return (%13))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 10, {1, 4, 9, 9}, {at::kCUDA});
+  auto w = at::randint(1, 10, {4, 3, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 4, 4),
+          %2 : Float(4)):
+      %3 : int = prim::Constant[value=1]()
+      %4 : int = prim::Constant[value=2]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=0]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : int[] = prim::ListConstruct(%4, %4)
+      %10 : int[] = prim::ListConstruct(%5, %5)
+      %11 : int[] = prim::ListConstruct(%6, %6)
+      %12 : int = prim::Constant[value=1]()
+      %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
+      return (%13))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});
+  auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA});
+  auto b = at::randint(1, 10, {3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 // TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
 //  const auto graph = R"IR(
 //    graph(%0 : Tensor,
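
Each test reshapes the TensorRT result to the JIT result's sizes before comparing. For a transposed convolution, the expected spatial size per dimension follows the standard formula; a small illustrative helper (not part of the test suite) makes the arithmetic concrete:

#include <cstdint>

// Illustrative only: standard output-size formula for one spatial dimension
// of a 2D transposed convolution.
int64_t conv_transpose_out_size(int64_t in, int64_t kernel, int64_t stride,
                                int64_t padding, int64_t dilation, int64_t out_padding) {
  return (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + out_padding + 1;
}

For ATenConvTransposeWithStrideConvertsCorrectly above (in = 9, kernel = 3, stride = 3, padding = 0, dilation = 1, output_padding = 0) this gives (9 - 1) * 3 + 1 * (3 - 1) + 1 = 27 per spatial dimension.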

tests/util/run_graph_engine.cpp (+3 -1)

@@ -65,7 +65,9 @@ std::vector<at::Tensor> RunGraphEngine(std::shared_ptr<torch::jit::Graph>& g,
                                        std::vector<at::Tensor> inputs) {
   LOG_DEBUG("Running TRT version");
   auto in = toInputRanges(inputs);
-  std::string eng = core::conversion::ConvertBlockToEngine(g->block(), in, named_params);
+  auto info = core::conversion::ConversionInfo(in);
+  info.engine_settings.workspace_size = 1 << 20;
+  std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
   return RunEngine(eng, inputs);
 }
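
The test harness now routes conversion options through a ConversionInfo struct instead of handing ConvertBlockToEngine the input ranges directly, and pins the builder workspace at 1 << 20 bytes (1 MiB), presumably plenty for these small test engines. A rough sketch of the struct's shape as used here; only the members the diff touches are taken from it, everything else below is assumed:

#include <cstdint>
#include <utility>
#include <vector>

// Rough sketch inferred from the diff above, not the real definition.
struct InputRange {             // stand-in for TRTorch's input shape range type
  std::vector<int64_t> min, opt, max;
};

struct BuilderSettings {
  uint64_t workspace_size = 0;  // scratch memory for the TensorRT builder, in bytes
};

struct ConversionInfo {
  std::vector<InputRange> input_ranges;
  BuilderSettings engine_settings;
  explicit ConversionInfo(std::vector<InputRange> in) : input_ranges(std::move(in)) {}
};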
