@@ -26,7 +26,6 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera
26
26
self = self_shuffle->getOutput (0 );
27
27
}
28
28
29
-
30
29
nvinfer1::ILayer* ele;
31
30
if (scalar != 1 ) {
32
31
LOG_WARNING (" Please verify scalar handling in add converter, channel axis set to 3 but scaling is uniform" );
@@ -73,8 +72,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
73
72
" aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor" ,
74
73
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
75
74
// Should implement self + alpha * other
76
- auto self = args[0 ].ITensor ( );
77
- auto other = args[1 ].ITensor ( );
75
+ auto self = args[0 ].ITensorOrFreeze (ctx );
76
+ auto other = args[1 ].ITensorOrFreeze (ctx );
78
77
auto scalar = args[2 ].unwrapToScalar ().to <float >();
79
78
auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, util::node_info (n), scalar);
80
79
@@ -90,8 +89,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
90
89
" aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ,
91
90
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
92
91
// Should implement self + alpha * other
93
- auto self = args[0 ].ITensor ( );
94
- auto other = args[1 ].ITensor ( );
92
+ auto self = args[0 ].ITensorOrFreeze (ctx );
93
+ auto other = args[1 ].ITensorOrFreeze (ctx );
95
94
auto scalar = args[2 ].unwrapToScalar ().to <float >();
96
95
auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, util::node_info (n), scalar);
97
96
@@ -107,8 +106,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
107
106
" aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor" ,
108
107
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
109
108
// Should implement self - alpha * other
110
- auto self = args[0 ].ITensor ( );
111
- auto other = args[1 ].ITensor ( );
109
+ auto self = args[0 ].ITensorOrFreeze (ctx );
110
+ auto other = args[1 ].ITensorOrFreeze (ctx );
112
111
auto scalar = args[2 ].unwrapToScalar ().to <float >();
113
112
auto sub = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , self, other, util::node_info (n), scalar);
114
113
@@ -124,8 +123,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
124
123
" aten::div.Tensor(Tensor self, Tensor other) -> Tensor" ,
125
124
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
126
125
// Should implement self / other
127
- auto self = args[0 ].ITensor ( );
128
- auto other = args[1 ].ITensor ( );
126
+ auto self = args[0 ].ITensorOrFreeze (ctx );
127
+ auto other = args[1 ].ITensorOrFreeze (ctx );
129
128
auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n));
130
129
131
130
TRTORCH_CHECK (div , " Unable to create div layer from node: " << *n);
@@ -140,8 +139,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
140
139
" aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)" ,
141
140
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
142
141
// TODO: Remove with functionalization
143
- auto self = args[0 ].ITensor ( );
144
- auto other = args[1 ].ITensor ( );
142
+ auto self = args[0 ].ITensorOrFreeze (ctx );
143
+ auto other = args[1 ].ITensorOrFreeze (ctx );
145
144
auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n));
146
145
147
146
TRTORCH_CHECK (div , " Unable to create div layer from node: " << *n);
@@ -156,8 +155,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
156
155
" aten::mul.Tensor(Tensor self, Tensor other) -> Tensor" ,
157
156
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
158
157
// Should implement self * other
159
- auto self = args[0 ].ITensor ( );
160
- auto other = args[1 ].ITensor ( );
158
+ auto self = args[0 ].ITensorOrFreeze (ctx );
159
+ auto other = args[1 ].ITensorOrFreeze (ctx );
161
160
auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other, util::node_info (n));
162
161
163
162
TRTORCH_CHECK (mul, " Unable to create mul layer from node: " << *n);
@@ -172,8 +171,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
172
171
" aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)" ,
173
172
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
174
173
// TODO: Remove with functionalization
175
- auto self = args[0 ].ITensor ( );
176
- auto other = args[1 ].ITensor ( );
174
+ auto self = args[0 ].ITensorOrFreeze (ctx );
175
+ auto other = args[1 ].ITensorOrFreeze (ctx );
177
176
auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other, util::node_info (n));
178
177
179
178
TRTORCH_CHECK (mul, " Unable to create mul layer from node: " << *n);
0 commit comments