@@ -8,18 +8,22 @@ namespace converters {
8
8
namespace impl {
9
9
namespace {
10
10
11
- nvinfer1::ILayer* add_elementwise (ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, float scalar=1 ) {
11
+ nvinfer1::ILayer* add_elementwise (ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name, float scalar=1 ) {
12
12
auto self_dims = self->getDimensions ();
13
+ auto self_dims_vec = util::toVec (self_dims);
13
14
auto other_dims = other->getDimensions ();
15
+ auto other_dims_vec = util::toVec (other_dims);
16
+ auto other_batch = other_dims_vec[0 ];
14
17
15
- TRTORCH_CHECK (util::volume (self_dims) == util::volume (other_dims), " Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims);
18
+ // TODO: Proper broadcast check
19
+ TRTORCH_CHECK (util::volume (self_dims) == util::volume (other_dims) || util::volume (self_dims) == util::volume (other_dims) / other_batch, " Found inputs to elementwise operation do not have the same number of elements or is not broadcastable:\n Found: self " << self_dims << " other " << other_dims);
16
20
17
21
if (self_dims != other_dims) {
18
22
LOG_DEBUG (" Input shape dont match inserting shuffle layers to reshape to " << self_dims);
19
- auto other_shuffle = ctx->net ->addShuffle (*other );
20
- other_shuffle ->setReshapeDimensions (self_dims );
21
- other_shuffle ->setName (std::string (" [Reshape other to " + util::toStr (self_dims) + ' ] ' ).c_str ());
22
- other = other_shuffle ->getOutput (0 );
23
+ auto self_shuffle = ctx->net ->addShuffle (*self );
24
+ self_shuffle ->setReshapeDimensions (util::toDimsPad (self_dims_vec, other_dims_vec. size ()) );
25
+ self_shuffle ->setName (std::string (" [Reshape self to " + util::toStr (self_dims) + " for broadcasting ( " + name + " )] " ).c_str ());
26
+ self = self_shuffle ->getOutput (0 );
23
27
}
24
28
25
29
@@ -72,7 +76,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
72
76
auto self = args[0 ].ITensor ();
73
77
auto other = args[1 ].ITensor ();
74
78
auto scalar = args[2 ].unwrapToScalar ().to <float >();
75
- auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, scalar);
79
+ auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, util::node_info (n), scalar);
76
80
77
81
TRTORCH_CHECK (add, " Unable to create add layer from node: " << *n);
78
82
@@ -89,7 +93,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
89
93
auto self = args[0 ].ITensor ();
90
94
auto other = args[1 ].ITensor ();
91
95
auto scalar = args[2 ].unwrapToScalar ().to <float >();
92
- auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, scalar);
96
+ auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, util::node_info (n), scalar);
93
97
94
98
TRTORCH_CHECK (add, " Unable to create add layer from node: " << *n);
95
99
@@ -106,7 +110,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
106
110
auto self = args[0 ].ITensor ();
107
111
auto other = args[1 ].ITensor ();
108
112
auto scalar = args[2 ].unwrapToScalar ().to <float >();
109
- auto sub = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , self, other, scalar);
113
+ auto sub = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , self, other, util::node_info (n), scalar);
110
114
111
115
TRTORCH_CHECK (sub, " Unable to create sub layer from node: " << *n);
112
116
@@ -122,7 +126,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
122
126
// Should implement self / other
123
127
auto self = args[0 ].ITensor ();
124
128
auto other = args[1 ].ITensor ();
125
- auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other);
129
+ auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n) );
126
130
127
131
TRTORCH_CHECK (div , " Unable to create div layer from node: " << *n);
128
132
@@ -138,7 +142,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
138
142
// TODO: Remove with functionalization
139
143
auto self = args[0 ].ITensor ();
140
144
auto other = args[1 ].ITensor ();
141
- auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other);
145
+ auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n) );
142
146
143
147
TRTORCH_CHECK (div , " Unable to create div layer from node: " << *n);
144
148
@@ -154,7 +158,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
154
158
// Should implement self * other
155
159
auto self = args[0 ].ITensor ();
156
160
auto other = args[1 ].ITensor ();
157
- auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other);
161
+ auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other, util::node_info (n) );
158
162
159
163
TRTORCH_CHECK (mul, " Unable to create mul layer from node: " << *n);
160
164
@@ -170,7 +174,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
170
174
// TODO: Remove with functionalization
171
175
auto self = args[0 ].ITensor ();
172
176
auto other = args[1 ].ITensor ();
173
- auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other);
177
+ auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other, util::node_info (n) );
174
178
175
179
TRTORCH_CHECK (mul, " Unable to create mul layer from node: " << *n);
176
180
0 commit comments