Skip to content

Commit a9f33e4

Browse files
committed
fix(//core/conversion/converters/impl/element_wise): Fix broadcast
support Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 0548540 commit a9f33e4

File tree

5 files changed

+24
-19
lines changed

5 files changed

+24
-19
lines changed

core/conversion/converters/impl/element_wise.cpp

+17-13
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,22 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

11-
nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, float scalar=1) {
11+
nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name, float scalar=1) {
1212
auto self_dims = self->getDimensions();
13+
auto self_dims_vec = util::toVec(self_dims);
1314
auto other_dims = other->getDimensions();
15+
auto other_dims_vec = util::toVec(other_dims);
16+
auto other_batch = other_dims_vec[0];
1417

15-
TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims), "Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims);
18+
// TODO: Proper broadcast check
19+
TRTORCH_CHECK(util::volume(self_dims) == util::volume(other_dims) || util::volume(self_dims) == util::volume(other_dims) / other_batch, "Found inputs to elementwise operation do not have the same number of elements or is not broadcastable:\n Found: self " << self_dims << " other " << other_dims);
1620

1721
if (self_dims != other_dims) {
1822
LOG_DEBUG("Input shape dont match inserting shuffle layers to reshape to " << self_dims);
19-
auto other_shuffle = ctx->net->addShuffle(*other);
20-
other_shuffle->setReshapeDimensions(self_dims);
21-
other_shuffle->setName(std::string("[Reshape other to " + util::toStr(self_dims) + ']').c_str());
22-
other = other_shuffle->getOutput(0);
23+
auto self_shuffle = ctx->net->addShuffle(*self);
24+
self_shuffle->setReshapeDimensions(util::toDimsPad(self_dims_vec, other_dims_vec.size()));
25+
self_shuffle->setName(std::string("[Reshape self to " + util::toStr(self_dims) + " for broadcasting (" + name + ")]").c_str());
26+
self = self_shuffle->getOutput(0);
2327
}
2428

2529

@@ -72,7 +76,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
7276
auto self = args[0].ITensor();
7377
auto other = args[1].ITensor();
7478
auto scalar = args[2].unwrapToScalar().to<float>();
75-
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar);
79+
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);
7680

7781
TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n);
7882

@@ -89,7 +93,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
8993
auto self = args[0].ITensor();
9094
auto other = args[1].ITensor();
9195
auto scalar = args[2].unwrapToScalar().to<float>();
92-
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, scalar);
96+
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);
9397

9498
TRTORCH_CHECK(add, "Unable to create add layer from node: " << *n);
9599

@@ -106,7 +110,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
106110
auto self = args[0].ITensor();
107111
auto other = args[1].ITensor();
108112
auto scalar = args[2].unwrapToScalar().to<float>();
109-
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, scalar);
113+
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, util::node_info(n), scalar);
110114

111115
TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
112116

@@ -122,7 +126,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
122126
// Should implement self / other
123127
auto self = args[0].ITensor();
124128
auto other = args[1].ITensor();
125-
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other);
129+
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
126130

127131
TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
128132

@@ -138,7 +142,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
138142
// TODO: Remove with functionalization
139143
auto self = args[0].ITensor();
140144
auto other = args[1].ITensor();
141-
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other);
145+
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
142146

143147
TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
144148

@@ -154,7 +158,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
154158
// Should implement self * other
155159
auto self = args[0].ITensor();
156160
auto other = args[1].ITensor();
157-
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other);
161+
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
158162

159163
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
160164

@@ -170,7 +174,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
170174
// TODO: Remove with functionalization
171175
auto self = args[0].ITensor();
172176
auto other = args[1].ITensor();
173-
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other);
177+
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
174178

175179
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
176180

tests/accuracy/accuracy_test.h

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class AccuracyTests
2020
std::cerr << "error loading the model\n";
2121
return;
2222
}
23+
mod.eval();
2324
}
2425

2526
void TearDown() {

tests/accuracy/test_fp16_accuracy.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
2424
jit_total += targets.sizes()[0];
2525
jit_correct += torch::sum(torch::eq(predictions, targets));
2626
}
27-
torch::Tensor jit_accuracy = jit_correct / jit_total;
27+
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
2828

2929
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
3030
auto extra_info = trtorch::ExtraInfo({input_shape});
@@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
4545
trt_correct += torch::sum(torch::eq(predictions, targets));
4646
}
4747

48-
torch::Tensor trt_accuracy = trt_correct / trt_total;
48+
torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
4949

5050
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
5151
}

tests/accuracy/test_fp32_accuracy.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
2424
jit_total += targets.sizes()[0];
2525
jit_correct += torch::sum(torch::eq(predictions, targets));
2626
}
27-
torch::Tensor jit_accuracy = jit_correct / jit_total;
27+
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
2828

2929
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
3030
auto extra_info = trtorch::ExtraInfo({input_shape});
@@ -45,7 +45,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
4545
trt_correct += torch::sum(torch::eq(predictions, targets));
4646
}
4747

48-
torch::Tensor trt_accuracy = trt_correct / trt_total;
48+
torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
4949

5050
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
5151
}

tests/accuracy/test_int8_accuracy.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
5454
jit_total += targets.sizes()[0];
5555
jit_correct += torch::sum(torch::eq(predictions, targets));
5656
}
57-
torch::Tensor jit_accuracy = jit_correct / jit_total;
57+
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
5858

5959
// Compile Graph
6060
auto trt_mod = trtorch::CompileGraph(mod, extra_info);
@@ -72,7 +72,7 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
7272
trt_total += targets.sizes()[0];
7373
trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
7474
}
75-
torch::Tensor trt_accuracy = trt_correct / trt_total;
75+
torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
7676

7777
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_accuracy, trt_accuracy, 3));
7878
}

0 commit comments

Comments
 (0)