Skip to content

Commit 2ee11f0

Browse files
silehtmergify[bot]
authored andcommitted
feat: use DTO for NCNN init parameters
1 parent cbbbd99 commit 2ee11f0

File tree

3 files changed

+49
-46
lines changed

3 files changed

+49
-46
lines changed

src/apidata.h

+25
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,31 @@ namespace dd
266266
*/
267267
void toJDoc(JDoc &jd) const;
268268

269+
/**
270+
* \brief converts APIData to oat++ DTO
271+
*/
272+
template <typename T> inline std::shared_ptr<T> createSharedDTO() const
273+
{
274+
rapidjson::Document d;
275+
d.SetObject();
276+
toJDoc(reinterpret_cast<JDoc &>(d));
277+
278+
rapidjson::StringBuffer buffer;
279+
rapidjson::Writer<rapidjson::StringBuffer, rapidjson::UTF8<>,
280+
rapidjson::UTF8<>, rapidjson::CrtAllocator,
281+
rapidjson::kWriteNanAndInfFlag>
282+
writer(buffer);
283+
bool done = d.Accept(writer);
284+
if (!done)
285+
throw DataConversionException("JSON rendering failed");
286+
287+
std::shared_ptr<oatpp::data::mapping::ObjectMapper> object_mapper
288+
= oatpp::parser::json::mapping::ObjectMapper::createShared();
289+
return object_mapper
290+
->readFromString<oatpp::Object<T>>(buffer.GetString())
291+
.getPtr();
292+
}
293+
269294
/**
270295
* \brief converts APIData to rapidjson JSON value
271296
* @param jd JSON Document hosting the destination JSON value

src/backends/ncnn/ncnnlib.cc

+17-39
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "outputconnectorstrategy.h"
2323
#include <thread>
2424
#include <algorithm>
25-
#include "utils/utils.hpp"
2625

2726
// NCNN
2827
#include "ncnnlib.h"
@@ -53,10 +52,10 @@ namespace dd
5352
{
5453
this->_libname = "ncnn";
5554
_net = new ncnn::Net();
56-
_net->opt.num_threads = _threads;
55+
_net->opt.num_threads = 1;
5756
_net->opt.blob_allocator = &_blob_pool_allocator;
5857
_net->opt.workspace_allocator = &_workspace_pool_allocator;
59-
_net->opt.lightmode = _lightmode;
58+
_net->opt.lightmode = true;
6059
}
6160

6261
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -69,12 +68,9 @@ namespace dd
6968
this->_libname = "ncnn";
7069
_net = tl._net;
7170
tl._net = nullptr;
72-
_nclasses = tl._nclasses;
73-
_threads = tl._threads;
7471
_timeserie = tl._timeserie;
7572
_old_height = tl._old_height;
76-
_inputBlob = tl._inputBlob;
77-
_outputBlob = tl._outputBlob;
73+
_init_dto = tl._init_dto;
7874
}
7975

8076
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -94,6 +90,8 @@ namespace dd
9490
void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
9591
TMLModel>::init_mllib(const APIData &ad)
9692
{
93+
_init_dto = ad.createSharedDTO<NcnnInitDto>();
94+
9795
bool use_fp32 = (ad.has("datatype")
9896
&& ad.get("datatype").get<std::string>()
9997
== "fp32"); // default is fp16
@@ -124,35 +122,11 @@ namespace dd
124122
_old_height = this->_inputc.height();
125123
_net->set_input_h(_old_height);
126124

127-
if (ad.has("nclasses"))
128-
_nclasses = ad.get("nclasses").get<int>();
129-
130-
if (ad.has("threads"))
131-
_threads = ad.get("threads").get<int>();
132-
else
133-
_threads = dd_utils::my_hardware_concurrency();
134-
135125
_timeserie = this->_inputc._timeserie;
136126
if (_timeserie)
137127
this->_mltype = "timeserie";
138128

139-
if (ad.has("lightmode"))
140-
{
141-
_lightmode = ad.get("lightmode").get<bool>();
142-
_net->opt.lightmode = _lightmode;
143-
}
144-
145-
// setting the value of Input Layer
146-
if (ad.has("inputblob"))
147-
{
148-
_inputBlob = ad.get("inputblob").get<std::string>();
149-
}
150-
// setting the final Output Layer
151-
if (ad.has("outputblob"))
152-
{
153-
_outputBlob = ad.get("outputblob").get<std::string>();
154-
}
155-
129+
_net->opt.lightmode = _init_dto->lightmode;
156130
_blob_pool_allocator.set_size_compare_ratio(0.0f);
157131
_workspace_pool_allocator.set_size_compare_ratio(0.5f);
158132
model_type(this->_mlmodel._params, this->_mltype);
@@ -233,7 +207,10 @@ namespace dd
233207

234208
// Extract detection or classification
235209
int ret = 0;
236-
std::string out_blob = _outputBlob;
210+
std::string out_blob;
211+
if (_init_dto->outputBlob != nullptr)
212+
out_blob = _init_dto->outputBlob->std_str();
213+
237214
if (out_blob.empty())
238215
{
239216
if (bbox == true)
@@ -262,11 +239,11 @@ namespace dd
262239
{
263240
best = ad_output.get("best").get<int>();
264241
}
265-
if (best == -1 || best > _nclasses)
266-
best = _nclasses;
242+
if (best == -1 || best > _init_dto->nclasses)
243+
best = _init_dto->nclasses;
267244

268245
// for loop around batch size
269-
#pragma omp parallel for num_threads(_threads)
246+
#pragma omp parallel for num_threads(*_init_dto->threads)
270247
for (size_t b = 0; b < inputc._ids.size(); b++)
271248
{
272249
std::vector<double> probs;
@@ -276,8 +253,8 @@ namespace dd
276253
APIData rad;
277254

278255
ncnn::Extractor ex = _net->create_extractor();
279-
ex.set_num_threads(_threads);
280-
ex.input(_inputBlob.c_str(), inputc._in.at(b));
256+
ex.set_num_threads(_init_dto->threads);
257+
ex.input(_init_dto->inputBlob->c_str(), inputc._in.at(b));
281258

282259
ret = ex.extract(out_blob.c_str(), inputc._out.at(b));
283260
if (ret == -1)
@@ -423,7 +400,8 @@ namespace dd
423400
} // end for batch_size
424401

425402
tout.add_results(vrad);
426-
out.add("nclasses", this->_nclasses);
403+
int nclasses = this->_init_dto->nclasses;
404+
out.add("nclasses", nclasses);
427405
if (bbox == true)
428406
out.add("bbox", true);
429407
out.add("roi", false);

src/backends/ncnn/ncnnlib.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@
2222
#ifndef NCNNLIB_H
2323
#define NCNNLIB_H
2424

25+
#include "apidata.h"
26+
#include "utils/utils.hpp"
27+
28+
#include "dto/ncnn.hpp"
29+
2530
// NCNN
2631
#include "net.h"
2732
#include "ncnnmodel.h"
2833

29-
#include "apidata.h"
30-
3134
namespace dd
3235
{
3336
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -53,20 +56,17 @@ namespace dd
5356

5457
public:
5558
ncnn::Net *_net = nullptr;
56-
int _nclasses = 0;
5759
bool _timeserie = false;
58-
bool _lightmode = true;
5960

6061
private:
62+
std::shared_ptr<NcnnInitDto> _init_dto;
6163
static ncnn::UnlockedPoolAllocator _blob_pool_allocator;
6264
static ncnn::PoolAllocator _workspace_pool_allocator;
6365

6466
protected:
65-
int _threads = 1;
6667
int _old_height = -1;
67-
std::string _inputBlob = "data";
68-
std::string _outputBlob;
6968
};
69+
7070
}
7171

7272
#endif

0 commit comments

Comments
 (0)