Skip to content

Commit f022dfe

Browse files
committed
feat(//cpp/api): Functional Dataloader based PTQ
- Couple assorted fixes in conversion implementation
- Set up the space to have phase specific settings inside the compiler
- PTQ Calibrator implementation moved to the public API, means Python will need its own but it probably did anyway
- PTQ now works with dataloader and all the overrides for Calibration algorithm work
- CIFAR10 Dataloader implementation
- Application still has bugs in reporting accuracy and reading from calibration cache

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 676bf56 commit f022dfe

28 files changed

+758
-261
lines changed

.gitignore

+5-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,8 @@ py/.eggs
1818
._DS_Store
1919
*.pth
2020
*.pyc
21-
cpp/ptq/training/vgg16/data/
21+
cpp/ptq/training/vgg16/data/*
22+
*.bin
23+
cpp/ptq/datasets/data/
24+
._.DS_Store
25+
*.tar.gz

core/BUILD

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ cc_library(
1616
"@libtorch//:libtorch",
1717
"@tensorrt//:nvinfer"
1818
],
19-
alwayslink=True,
19+
alwayslink=True,
2020
)
2121

2222

core/compiler.cpp

+21-19
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,24 @@
2424
namespace trtorch {
2525
namespace core {
2626

27-
c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
27+
c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {
2828

2929
std::vector<c10::Argument> args;
3030
for (auto in : g->inputs()) {
3131
args.push_back(c10::Argument(in->debugName(), in->type()));
3232
}
33-
33+
3434
std::vector<c10::Argument> returns;
3535
for (auto out : g->outputs()) {
3636
returns.push_back(c10::Argument(out->debugName(), out->type()));
3737
}
38-
38+
3939
return c10::FunctionSchema(method_name, method_name, args, returns);
4040
}
4141

4242

4343
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
44-
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
44+
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
4545
auto schema = execution::GetEngineFunctionSchema(uid);
4646
auto num_io = execution::GetEngineIO(uid);
4747

@@ -53,14 +53,14 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
5353
in_val->setType(c10::TensorType::get());
5454
graph_inputs.push_back(in_val);
5555
}
56-
56+
5757
auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second);
5858
g->block()->appendNode(engine_node);
5959

6060
for (auto o : engine_node->outputs()) {
6161
g->registerOutput(o);
6262
}
63-
63+
6464
return;
6565
}
6666

@@ -69,48 +69,50 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
6969
auto g = mod.get_method(method_name).graph();
7070
// Go through PyTorch Lowering to simplify graph and extract weight parameters
7171
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
72-
72+
7373
g = graph_and_parameters.first;
74-
74+
7575
// Go through TRTorch Lowering to reformat graph to be conversion friendly
7676
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
7777
lowering::LowerGraph(g);
78-
78+
7979
auto params = graph_and_parameters.second;
8080
auto named_params = conversion::get_named_params(g->inputs(), params);
8181
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
82-
82+
8383
// Is this necessary?
8484
lowering::LowerBlock(g->block());
85-
85+
8686
return conversion::VerifyConverterSupportForBlock(g->block());
8787
}
8888

8989
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
9090
std::string method_name,
91-
conversion::ExtraInfo cfg) {
91+
ExtraInfo cfg) {
92+
auto convert_cfg = std::move(cfg.convert_info);
93+
9294
auto g = mod.get_method(method_name).graph();
9395
// Go through PyTorch Lowering to simplify graph and extract weight parameters
9496
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
95-
97+
9698
g = graph_and_parameters.first;
97-
99+
98100
// Go through TRTorch Lowering to reformat graph to be conversion friendly
99101
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
100102
lowering::LowerGraph(g);
101-
103+
102104
auto params = graph_and_parameters.second;
103105
auto named_params = conversion::get_named_params(g->inputs(), params);
104106
LOG_INFO(*g << "(CompileGraph)\n");
105-
107+
106108
// Is this necessary?
107109
lowering::LowerBlock(g->block());
108-
auto engine = ConvertBlockToEngine(g->block(), cfg, named_params);
110+
auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
109111
return std::move(engine);
110112
}
111113

112114
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
113-
conversion::ExtraInfo cfg) {
115+
ExtraInfo cfg) {
114116
// TODO: Should be doing a functional transform but need PR #31978
115117
// [jit] More robust mangling
116118
// torch::jit::script::Module new_mod = mod.clone();
@@ -128,7 +130,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
128130

129131
return new_mod;
130132
}
131-
133+
132134
} // namespace core
133135
} // namespace trtorch
134136

core/compiler.h

+10-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,19 @@
66

77
namespace trtorch {
88
namespace core {
9+
10+
struct ExtraInfo {
11+
ExtraInfo(std::vector<conversion::InputRange> input_ranges)
12+
: convert_info(std::move(input_ranges)) {}
13+
conversion::ConversionInfo convert_info;
14+
};
15+
916
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
1017

1118
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
12-
std::string method_name, conversion::ExtraInfo cfg);
19+
std::string method_name, ExtraInfo cfg);
1320

14-
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg);
21+
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);
1522

1623
} // namespace core
17-
} // namespace trtorch
24+
} // namespace trtorch

core/conversion/conversion.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
179179
}
180180
}
181181

182-
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
182+
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
183183
LOG_INFO(ctx->logger, "Converting Block");
184184

185185
auto inputs = b->inputs();
@@ -221,7 +221,7 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraI
221221
// a serialized TensorRT engine that can be deserialized and run
222222

223223
// Probably should consolidate these two functions
224-
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
224+
std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
225225
ConversionCtx ctx(build_info.engine_settings);
226226
ConvertBlockToNetDef(&ctx, b, build_info, static_params);
227227
std::string engine = ctx.SerializeEngine();

core/conversion/conversion.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ struct InputRange {
3030
std::vector<int64_t> max_shape);
3131
};
3232

33-
struct ExtraInfo {
33+
struct ConversionInfo {
3434
std::vector<InputRange> input_ranges;
3535
BuilderSettings engine_settings;
36-
ExtraInfo(std::vector<InputRange> input_ranges)
36+
ConversionInfo(std::vector<InputRange> input_ranges)
3737
: input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
3838
};
3939

@@ -43,7 +43,7 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect
4343

4444
// Converts a already lowered block (blocks with no sub blocks) to
4545
// a serialized TensorRT engine that can be deserialized and run
46-
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);
46+
std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params);
4747

4848
bool OpSupported(const torch::jit::Node* n);
4949

core/conversion/conversionctx/ConversionCtx.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
2020
<< "\n Max Workspace Size: " << s.workspace_size \
2121
<< "\n Device Type: " << s.device \
2222
<< "\n Engine Capability: " << s.capability \
23-
<< "\n Calibrator Created: " << s.calibrator ? true : false;
23+
<< "\n Calibrator Created: " << (s.calibrator != nullptr);
2424
return os;
2525
}
2626

core/conversion/converters/impl/batch_norm.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ volatile auto batch_norm_registrations = RegisterNodeConversionPatterns()
8383
auto gamma = args[1].unwrapToTensor();
8484

8585
if (/*training*/ args[5].unwrapToBool()) {
86-
LOG_WARNING("TensorRT only converts forward pass of graphs, but saw training = True, may see undefined behavior, consider placing module in eval mode");
86+
LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see
87+
unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN");
8788
}
8889

8990
// If gamma is None this fails

core/conversion/converters/impl/pooling.cpp

+3-6
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,17 @@ auto pooling_registrations = RegisterNodeConversionPatterns()
7979
for (size_t i = 0; i < out_shape.size(); i++) {
8080
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_shape[(out_shape.size() - 1) - i];
8181
}
82-
LOG_DEBUG("Stride" << util::toDims(stride));
82+
LOG_DEBUG("Stride: " << util::toDims(stride));
8383

8484
std::vector<int64_t> window(out_shape.size());
8585
for (size_t i = 0; i < out_shape.size(); i++) {
8686
window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_shape[out_shape.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
8787
}
8888

89-
LOG_DEBUG("Window" << util::toDims(window));
89+
LOG_DEBUG("Window: " << util::toDims(window));
9090

9191
auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
92-
if (!new_layer) {
93-
LOG_ERROR("Unable to create average pooling layer from node: " << *n);
94-
return false;
95-
}
92+
TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);
9693

9794
new_layer->setStrideNd(util::toDims(stride));
9895

core/quantization/BUILD

Whitespace-only changes.

core/quantization/TRTEntropyCalibrator.cpp

-64
This file was deleted.

core/quantization/quantization.h

-69
This file was deleted.

0 commit comments

Comments (0)