22
22
#include " outputconnectorstrategy.h"
23
23
#include < thread>
24
24
#include < algorithm>
25
- #include " utils/utils.hpp"
26
25
27
26
// NCNN
28
27
#include " ncnnlib.h"
@@ -53,10 +52,10 @@ namespace dd
53
52
{
54
53
this ->_libname = " ncnn" ;
55
54
_net = new ncnn::Net ();
56
- _net->opt .num_threads = _threads ;
55
+ _net->opt .num_threads = 1 ;
57
56
_net->opt .blob_allocator = &_blob_pool_allocator;
58
57
_net->opt .workspace_allocator = &_workspace_pool_allocator;
59
- _net->opt .lightmode = _lightmode ;
58
+ _net->opt .lightmode = true ;
60
59
}
61
60
62
61
template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -69,12 +68,9 @@ namespace dd
69
68
this ->_libname = " ncnn" ;
70
69
_net = tl._net ;
71
70
tl._net = nullptr ;
72
- _nclasses = tl._nclasses ;
73
- _threads = tl._threads ;
74
71
_timeserie = tl._timeserie ;
75
72
_old_height = tl._old_height ;
76
- _inputBlob = tl._inputBlob ;
77
- _outputBlob = tl._outputBlob ;
73
+ _init_dto = tl._init_dto ;
78
74
}
79
75
80
76
template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -94,6 +90,8 @@ namespace dd
94
90
void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
95
91
TMLModel>::init_mllib(const APIData &ad)
96
92
{
93
+ _init_dto = ad.createSharedDTO <NcnnInitDto>();
94
+
97
95
bool use_fp32 = (ad.has (" datatype" )
98
96
&& ad.get (" datatype" ).get <std::string>()
99
97
== " fp32" ); // default is fp16
@@ -124,35 +122,11 @@ namespace dd
124
122
_old_height = this ->_inputc .height ();
125
123
_net->set_input_h (_old_height);
126
124
127
- if (ad.has (" nclasses" ))
128
- _nclasses = ad.get (" nclasses" ).get <int >();
129
-
130
- if (ad.has (" threads" ))
131
- _threads = ad.get (" threads" ).get <int >();
132
- else
133
- _threads = dd_utils::my_hardware_concurrency ();
134
-
135
125
_timeserie = this ->_inputc ._timeserie ;
136
126
if (_timeserie)
137
127
this ->_mltype = " timeserie" ;
138
128
139
- if (ad.has (" lightmode" ))
140
- {
141
- _lightmode = ad.get (" lightmode" ).get <bool >();
142
- _net->opt .lightmode = _lightmode;
143
- }
144
-
145
- // setting the value of Input Layer
146
- if (ad.has (" inputblob" ))
147
- {
148
- _inputBlob = ad.get (" inputblob" ).get <std::string>();
149
- }
150
- // setting the final Output Layer
151
- if (ad.has (" outputblob" ))
152
- {
153
- _outputBlob = ad.get (" outputblob" ).get <std::string>();
154
- }
155
-
129
+ _net->opt .lightmode = _init_dto->lightmode ;
156
130
_blob_pool_allocator.set_size_compare_ratio (0 .0f );
157
131
_workspace_pool_allocator.set_size_compare_ratio (0 .5f );
158
132
model_type (this ->_mlmodel ._params , this ->_mltype );
@@ -233,7 +207,10 @@ namespace dd
233
207
234
208
// Extract detection or classification
235
209
int ret = 0 ;
236
- std::string out_blob = _outputBlob;
210
+ std::string out_blob;
211
+ if (_init_dto->outputBlob != nullptr )
212
+ out_blob = _init_dto->outputBlob ->std_str ();
213
+
237
214
if (out_blob.empty ())
238
215
{
239
216
if (bbox == true )
@@ -262,11 +239,11 @@ namespace dd
262
239
{
263
240
best = ad_output.get (" best" ).get <int >();
264
241
}
265
- if (best == -1 || best > _nclasses )
266
- best = _nclasses ;
242
+ if (best == -1 || best > _init_dto-> nclasses )
243
+ best = _init_dto-> nclasses ;
267
244
268
245
// for loop around batch size
269
- #pragma omp parallel for num_threads(_threads )
246
+ #pragma omp parallel for num_threads(*_init_dto->threads )
270
247
for (size_t b = 0 ; b < inputc._ids .size (); b++)
271
248
{
272
249
std::vector<double > probs;
@@ -276,8 +253,8 @@ namespace dd
276
253
APIData rad;
277
254
278
255
ncnn::Extractor ex = _net->create_extractor ();
279
- ex.set_num_threads (_threads );
280
- ex.input (_inputBlob. c_str (), inputc._in .at (b));
256
+ ex.set_num_threads (_init_dto-> threads );
257
+ ex.input (_init_dto-> inputBlob -> c_str (), inputc._in .at (b));
281
258
282
259
ret = ex.extract (out_blob.c_str (), inputc._out .at (b));
283
260
if (ret == -1 )
@@ -423,7 +400,8 @@ namespace dd
423
400
} // end for batch_size
424
401
425
402
tout.add_results (vrad);
426
- out.add (" nclasses" , this ->_nclasses );
403
+ int nclasses = this ->_init_dto ->nclasses ;
404
+ out.add (" nclasses" , nclasses);
427
405
if (bbox == true )
428
406
out.add (" bbox" , true );
429
407
out.add (" roi" , false );
0 commit comments