Skip to content

Commit 2ccf8d0

Browse files
abhi-iyernarendasan
authored and committed on Jul 28, 2020
feat(//core/conversion/var): created ITensorOrFreeze() method, to replace functionality of Var::ITensor()
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): updates to some comments on the PR
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): addressed PR comment
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): addressed PR comments
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): addressed PR comments
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): addressing PR comments
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): addressing PR comments
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): Addressed PR comments
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

fix(): bug in test_serialization, need to fix
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>

Delete converters.h.orig — delete .orig file
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>

Update activation.cpp — addressing PR comments
Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>

Delete converters.h.orig — delete .orig file
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>

Update activation.cpp — addressing PR comments
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>
1 parent 362c932 commit 2ccf8d0

17 files changed: +167 additions, -95 deletions
 

‎core/conversion/converters/BUILD

+23-2
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,32 @@ config_setting(
77
}
88
)
99

10+
cc_library(
11+
name = "weights",
12+
hdrs = [
13+
"Weights.h"
14+
],
15+
srcs = [
16+
"Weights.cpp"
17+
],
18+
deps = [
19+
"@tensorrt//:nvinfer",
20+
"//core/util:prelude",
21+
"//core/conversion/conversionctx"
22+
] + select({
23+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
24+
"//conditions:default": ["@libtorch//:libtorch"],
25+
}),
26+
alwayslink = True,
27+
)
28+
1029
cc_library(
1130
name = "converters",
1231
hdrs = [
1332
"converters.h"
1433
],
1534
srcs = [
1635
"NodeConverterRegistry.cpp",
17-
"Weights.cpp",
1836
"impl/activation.cpp",
1937
"impl/batch_norm.cpp",
2038
"impl/concat.cpp",
@@ -51,5 +69,8 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
5169
pkg_tar(
5270
name = "include",
5371
package_dir = "core/conversion/converters/",
54-
srcs = ["converters.h"],
72+
srcs = [
73+
"converters.h",
74+
"Weights.h"
75+
],
5576
)

‎core/conversion/converters/Weights.cpp

+22-9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#include "core/util/prelude.h"
2-
#include "core/conversion/converters/converters.h"
2+
#include "core/conversion/converters/Weights.h"
33

44
namespace trtorch {
55
namespace core {
66
namespace conversion {
77
namespace converters {
88

9-
109
Weights::Weights() {
1110
this->num_input_maps = 0;
1211
this->num_output_maps = 0;
@@ -18,20 +17,36 @@ Weights::Weights() {
1817
Weights::Weights(ConversionCtx* ctx, float val) {
1918
this->num_input_maps = 1;
2019
this->num_output_maps = 1;
20+
2121
this->data.type = nvinfer1::DataType::kFLOAT;
2222
float* buf = reinterpret_cast<float*>(malloc(1 * sizeof(float)));
2323
buf[0] = val;
2424
this->data.values = buf;
2525
this->data.count = 1;
2626
ctx->builder_resources.push_back(buf);
27-
this->kernel_shape.nbDims = 1;
28-
this->kernel_shape.d[0] = 1;
27+
28+
this->shape.nbDims = 0;
29+
this->kernel_shape.nbDims = 0;
30+
}
31+
32+
Weights::Weights(ConversionCtx* ctx, int32_t val) {
33+
this->num_input_maps = 1;
34+
this->num_output_maps = 1;
35+
36+
this->data.type = nvinfer1::DataType::kINT32;
37+
int32_t* buf = reinterpret_cast<int32_t*>(malloc(1 * sizeof(int32_t)));
38+
buf[0] = val;
39+
this->data.values = buf;
40+
this->data.count = 1;
41+
ctx->builder_resources.push_back(buf);
42+
43+
this->shape.nbDims = 0;
44+
this->kernel_shape.nbDims = 0;
2945
}
3046

3147
Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
3248
if (t.sizes().size() > nvinfer1::Dims::MAX_DIMS) {
33-
//TODO: Handle this with exceptions or whatever
34-
LOG_INTERNAL_ERROR("The tensor requested to be converted to nvinfer1::Weights exceeds the max number of dimensions for TensorRT");
49+
TRTORCH_THROW_ERROR("The tensor requested to be converted to nvinfer1::Weights exceeds the max number of dimensions for TensorRT");
3550
}
3651
this->shape = util::toDims(t.sizes());
3752
if (t.sizes().size() >= 2) {
@@ -59,9 +74,7 @@ Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
5974
t_cpu = t_cpu.contiguous();
6075
auto dtype_optional = util::toTRTDataType(t_cpu.dtype());
6176
if (!dtype_optional) {
62-
//TODO: Handle this with exceptions or whatever
63-
//TODO: Implement handling for the Torch Types
64-
LOG_INTERNAL_ERROR("The tensor requested to be converted to nvinfer1::Weights is of an unsupported type");
77+
TRTORCH_THROW_ERROR("The tensor requested to be converted to nvinfer1::Weights is of an unsupported type");
6578
}
6679

6780
// Store the data in the conversion context so it remains until building is complete

‎core/conversion/converters/Weights.h

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
#include "core/util/prelude.h"
4+
#include "core/conversion/conversionctx/ConversionCtx.h"
5+
6+
namespace trtorch {
7+
namespace core {
8+
namespace conversion {
9+
namespace converters {
10+
11+
struct Weights {
12+
nvinfer1::Weights data;
13+
nvinfer1::Dims kernel_shape;
14+
nvinfer1::Dims shape;
15+
int64_t num_input_maps;
16+
int64_t num_output_maps;
17+
18+
Weights();
19+
Weights(ConversionCtx* ctx, at::Tensor t);
20+
Weights(ConversionCtx* ctx, float val);
21+
Weights(ConversionCtx* ctx, int32_t val);
22+
friend std::ostream& operator<<(std::ostream& os, const Weights& w);
23+
};
24+
25+
inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
26+
auto t_weights = Weights(ctx, t);
27+
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
28+
TRTORCH_CHECK(const_layer, "Unable to freeze tensor");
29+
30+
auto out = const_layer->getOutput(0);
31+
32+
std::ostringstream tensor_id;
33+
tensor_id << reinterpret_cast<int*>(out);
34+
35+
LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
36+
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
37+
38+
return out;
39+
}
40+
41+
42+
} // namespace converters
43+
} // namespace conversion
44+
} // namespace core
45+
} // namespace trtorch

‎core/conversion/converters/converters.h

+1-22
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "core/util/prelude.h"
1010
#include "core/conversion/var/Var.h"
1111
#include "core/conversion/conversionctx/ConversionCtx.h"
12+
#include "core/conversion/converters/Weights.h"
1213

1314
namespace trtorch {
1415
namespace core {
@@ -39,28 +40,6 @@ class RegisterNodeConversionPatterns {
3940
bool node_is_convertable(const torch::jit::Node* n);
4041
OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature);
4142

42-
struct Weights {
43-
//TODO: Rebuild this in a way that makes sense for more than just conv2/3D and linear
44-
nvinfer1::Weights data;
45-
nvinfer1::Dims kernel_shape;
46-
nvinfer1::Dims shape;
47-
int64_t num_input_maps;
48-
int64_t num_output_maps;
49-
50-
Weights();
51-
Weights(ConversionCtx* ctx, at::Tensor t);
52-
Weights(ConversionCtx* ctx, float val);
53-
friend std::ostream& operator<<(std::ostream& os, const Weights& w);
54-
};
55-
56-
inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
57-
auto t_weights = Weights(ctx, t);
58-
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
59-
TRTORCH_CHECK(const_layer, "Unable to freeze tensor");
60-
const_layer->setName("[Freeze Tensor]");
61-
return const_layer->getOutput(0);
62-
}
63-
6443
} // namespace converters
6544
} // namespace conversion
6645
} // namespace core

‎core/conversion/converters/impl/activation.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace {
1010

1111
#define convert(act, trt_type) \
1212
bool act(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { \
13-
auto in = args[0].ITensor(); \
13+
auto in = args[0].ITensorOrFreeze(ctx); \
1414
\
1515
auto new_layer = \
1616
ctx->net->addActivation(*in, nvinfer1::ActivationType::trt_type); \
@@ -46,7 +46,7 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
4646
.pattern({
4747
"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)",
4848
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
49-
auto in = args[0].ITensor();
49+
auto in = args[0].ITensorOrFreeze(ctx);
5050
auto min = args[1].unwrapToDouble();
5151
auto max = args[2].unwrapToDouble();
5252

@@ -66,7 +66,7 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
6666
//TODO: Remove after functionalization
6767
"aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))",
6868
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
69-
auto in = args[0].ITensor();
69+
auto in = args[0].ITensorOrFreeze(ctx);
7070
auto min = args[1].unwrapToDouble();
7171
auto max = args[2].unwrapToDouble();
7272

‎core/conversion/converters/impl/batch_norm.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1515
Tensor? mean, Tensor? var,
1616
bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
1717
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
18-
auto input = args[0].ITensor();
18+
auto input = args[0].ITensor(); // assumes non-static input Tensor
1919
auto orig_shape = input->getDimensions();
2020
auto shape = util::toVec(orig_shape);
2121
auto options = torch::TensorOptions().dtype(torch::kFloat32);

‎core/conversion/converters/impl/conv_deconv.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1717
int[] output_padding, int groups, bool benchmark,
1818
bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
1919
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20-
auto in = args[0].ITensor();
20+
auto in = args[0].ITensor(); // assumes non-static input Tensor
2121

2222
auto w = Weights(ctx, args[1].unwrapToTensor());
2323
auto stride = util::toDims(args[3].unwrapToIntList());

‎core/conversion/converters/impl/element_wise.cpp

+14-15
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera
2626
self = self_shuffle->getOutput(0);
2727
}
2828

29-
3029
nvinfer1::ILayer* ele;
3130
if (scalar != 1) {
3231
LOG_WARNING("Please verify scalar handling in add converter, channel axis set to 3 but scaling is uniform");
@@ -73,8 +72,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
7372
"aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
7473
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
7574
// Should implement self + alpha * other
76-
auto self = args[0].ITensor();
77-
auto other = args[1].ITensor();
75+
auto self = args[0].ITensorOrFreeze(ctx);
76+
auto other = args[1].ITensorOrFreeze(ctx);
7877
auto scalar = args[2].unwrapToScalar().to<float>();
7978
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);
8079

@@ -90,8 +89,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
9089
"aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))",
9190
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
9291
// Should implement self + alpha * other
93-
auto self = args[0].ITensor();
94-
auto other = args[1].ITensor();
92+
auto self = args[0].ITensorOrFreeze(ctx);
93+
auto other = args[1].ITensorOrFreeze(ctx);
9594
auto scalar = args[2].unwrapToScalar().to<float>();
9695
auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);
9796

@@ -107,8 +106,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
107106
"aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
108107
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
109108
// Should implement self - alpha * other
110-
auto self = args[0].ITensor();
111-
auto other = args[1].ITensor();
109+
auto self = args[0].ITensorOrFreeze(ctx);
110+
auto other = args[1].ITensorOrFreeze(ctx);
112111
auto scalar = args[2].unwrapToScalar().to<float>();
113112
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, util::node_info(n), scalar);
114113

@@ -124,8 +123,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
124123
"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
125124
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
126125
// Should implement self / other
127-
auto self = args[0].ITensor();
128-
auto other = args[1].ITensor();
126+
auto self = args[0].ITensorOrFreeze(ctx);
127+
auto other = args[1].ITensorOrFreeze(ctx);
129128
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
130129

131130
TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -140,8 +139,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
140139
"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
141140
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
142141
// TODO: Remove with functionalization
143-
auto self = args[0].ITensor();
144-
auto other = args[1].ITensor();
142+
auto self = args[0].ITensorOrFreeze(ctx);
143+
auto other = args[1].ITensorOrFreeze(ctx);
145144
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
146145

147146
TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -156,8 +155,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
156155
"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
157156
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
158157
// Should implement self * other
159-
auto self = args[0].ITensor();
160-
auto other = args[1].ITensor();
158+
auto self = args[0].ITensorOrFreeze(ctx);
159+
auto other = args[1].ITensorOrFreeze(ctx);
161160
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
162161

163162
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
@@ -172,8 +171,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
172171
"aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
173172
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
174173
// TODO: Remove with functionalization
175-
auto self = args[0].ITensor();
176-
auto other = args[1].ITensor();
174+
auto self = args[0].ITensorOrFreeze(ctx);
175+
auto other = args[1].ITensorOrFreeze(ctx);
177176
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
178177

179178
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);

‎core/conversion/converters/impl/linear.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1414
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1515
// PyTorch follows in: Nx*xIN, W: OUTxIN, B: OUT, out: Nx*xOUT
1616
// TensorRT inserts a flatten in when following conv
17-
auto in = args[0].ITensor();
17+
auto in = args[0].ITensorOrFreeze(ctx);
1818
auto shape = util::toVec(in->getDimensions());
1919

2020
LOG_DEBUG("Input tensor shape: " << in->getDimensions());

‎core/conversion/converters/impl/matrix_multiply.cpp

+2-22
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,10 @@ auto mm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1212
.pattern({
1313
"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
1414
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15-
nvinfer1::ITensor* self;
16-
if (args[0].isIValue()) {
17-
auto t = args[0].unwrapToTensor();
18-
auto t_weights = Weights(ctx, t);
19-
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
20-
TRTORCH_CHECK(const_layer, "Unable to freeze tensor self for node: " << *n);
21-
const_layer->setName((util::node_info(n) + " [Freeze Tensor(self)]").c_str());
22-
self = const_layer->getOutput(0);
23-
} else {
24-
self = args[0].ITensor();
25-
}
15+
auto self = args[0].ITensorOrFreeze(ctx);
2616
LOG_DEBUG("self tensor shape: " << self->getDimensions());
2717

28-
nvinfer1::ITensor* other;
29-
if (args[1].isIValue()) {
30-
auto t = args[1].unwrapToTensor();
31-
auto t_weights = Weights(ctx, t);
32-
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
33-
TRTORCH_CHECK(const_layer, "Unable to freeze tensor other for node: " << *n);
34-
const_layer->setName((util::node_info(n) + " [Freeze Tensor(other)]").c_str());
35-
other = const_layer->getOutput(0);
36-
} else {
37-
other = args[1].ITensor();
38-
}
18+
auto other = args[1].ITensorOrFreeze(ctx);
3919
LOG_DEBUG("other tensor shape: " << other->getDimensions());
4020

4121
auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);

0 commit comments

Comments
 (0)