Skip to content

Commit cafcced

Browse files
committedJun 17, 2020
fix(plugin): trying to fix bug in plugin
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
1 parent f0fefaa commit cafcced

File tree

3 files changed

+33
-12
lines changed

3 files changed

+33
-12
lines changed
 

‎core/conversion/converters/impl/interpolate.cpp

+13-6
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,26 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
118118

119119
TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch");
120120

121-
auto out_shape = in_shape;
121+
auto out_shape = in_shape;
122122
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
123123

124124
if (!align_corners) {
125-
//auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
126-
//auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
127-
auto creator = new plugins::InterpolatePluginCreator();
125+
//auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
126+
std::raise(SIGINT);
128127

129-
auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
128+
//auto creator_auto = getPluginRegistry()->getPluginCreator("interpolate", "1");
129+
//auto plugin_auto = creator_auto->createPlugin(util::node_info(n).c_str(), nullptr);
130130

131-
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(in), 1, *plugin);
131+
//auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
132+
133+
auto creator = new plugins::InterpolatePluginCreator();
134+
auto plugin = creator->createPlugin("interpolate_plugin", in_shape, out_shape, out_size, std::string("linear"), align_corners);
135+
136+
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
137+
resize_layer->setName(util::node_info(n).c_str());
132138

133139
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
140+
134141
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
135142
} else {
136143
auto resize_layer = ctx->net->addResize(*in);

‎core/conversion/converters/impl/plugins/interpolate_plugin.cpp

+13-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ namespace conversion {
2121
namespace converters {
2222
namespace impl {
2323
namespace plugins {
24-
namespace {
2524

2625
/*
2726
* InterpolatePlugin class implementations
@@ -64,6 +63,18 @@ InterpolatePlugin::InterpolatePlugin(const char *data, size_t length) {
6463
}
6564
}
6665

66+
std::vector<int64_t> InterpolatePlugin::getInputShape() {
67+
return in_shape;
68+
}
69+
70+
std::vector<int64_t> InterpolatePlugin::getOutputShape() {
71+
return out_shape;
72+
}
73+
74+
std::vector<int64_t> InterpolatePlugin::getOutputSize() {
75+
return size;
76+
}
77+
6778
int InterpolatePlugin::getNbOutputs() const {
6879
return 1;
6980
}
@@ -206,7 +217,7 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co
206217
return nullptr;
207218
}
208219

209-
nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) {
220+
nvinfer1::IPluginV2DynamicExt* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) {
210221
name = name;
211222
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
212223
}
@@ -222,7 +233,6 @@ const nvinfer1::PluginFieldCollection* InterpolatePluginCreator::getFieldNames()
222233

223234
REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator);
224235

225-
} // namespace
226236
} // namespace plugins
227237
} // namespace impl
228238
} // namespace converters

‎core/conversion/converters/impl/plugins/interpolate_plugin.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ namespace conversion {
2222
namespace converters {
2323
namespace impl {
2424
namespace plugins {
25-
namespace {
2625

2726
class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
2827
private:
@@ -52,6 +51,12 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
5251

5352
InterpolatePlugin() = delete;
5453

54+
std::vector<int64_t> getInputShape();
55+
56+
std::vector<int64_t> getOutputShape();
57+
58+
std::vector<int64_t> getOutputSize();
59+
5560
int getNbOutputs() const override;
5661

5762
const char* getPluginType() const override;
@@ -110,14 +115,13 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
110115

111116
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override;
112117

113-
nvinfer1::IPluginV2* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);
118+
nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);
114119

115120
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override;
116121

117122
const nvinfer1::PluginFieldCollection* getFieldNames() override;
118123
};
119124

120-
} // namespace
121125
} // namespace plugins
122126
} // namespace impl
123127
} // namespace converters

0 commit comments

Comments
 (0)