@@ -564,52 +564,54 @@ namespace dd
564
564
+ out_blob);
565
565
}
566
566
567
- if (_bbox)
568
- {
569
- _outputIndex1 = _engine->getBindingIndex (" keep_count" );
570
- _buffers.resize (3 );
571
- _floatOut.resize (_max_batch_size * _top_k * 7 );
572
- _keepCount.resize (_max_batch_size);
573
- if (inputc._bw )
574
- cudaMalloc (&_buffers.data ()[_inputIndex],
575
- _max_batch_size * inputc._height * inputc._width
576
- * sizeof (float ));
577
- else
578
- cudaMalloc (&_buffers.data ()[_inputIndex],
579
- _max_batch_size * 3 * inputc._height * inputc._width
580
- * sizeof (float ));
581
- cudaMalloc (&_buffers.data ()[_outputIndex0],
582
- _max_batch_size * _top_k * 7 * sizeof (float ));
583
- cudaMalloc (&_buffers.data ()[_outputIndex1],
584
- _max_batch_size * sizeof (int ));
585
- }
586
- else if (_ctc)
587
- {
588
- throw MLLibBadParamException (
589
- " ocr not yet implemented over tensorRT backend" );
590
- }
591
- else if (_timeserie)
567
+ if (_first_predict)
592
568
{
593
- throw MLLibBadParamException (
594
- " timeseries not yet implemented over tensorRT backend" );
595
- }
596
- else // classification / regression
597
- {
598
- if (_regression)
599
- _buffers.resize (1 );
600
- else
601
- _buffers.resize (2 );
602
- _floatOut.resize (_max_batch_size * this ->_nclasses );
603
- if (inputc._bw )
604
- cudaMalloc (&_buffers.data ()[_inputIndex],
605
- _max_batch_size * inputc._height * inputc._width
606
- * sizeof (float ));
607
- else
608
- cudaMalloc (&_buffers.data ()[_inputIndex],
609
- _max_batch_size * 3 * inputc._height * inputc._width
610
- * sizeof (float ));
611
- cudaMalloc (&_buffers.data ()[_outputIndex0],
612
- _max_batch_size * _nclasses * sizeof (float ));
569
+ _first_predict = false ;
570
+
571
+ if (_bbox)
572
+ {
573
+ _outputIndex1 = _engine->getBindingIndex (" keep_count" );
574
+ _buffers.resize (3 );
575
+ _floatOut.resize (_max_batch_size * _top_k * 7 );
576
+ _keepCount.resize (_max_batch_size);
577
+ if (inputc._bw )
578
+ cudaMalloc (&_buffers.data ()[_inputIndex],
579
+ _max_batch_size * inputc._height * inputc._width
580
+ * sizeof (float ));
581
+ else
582
+ cudaMalloc (&_buffers.data ()[_inputIndex],
583
+ _max_batch_size * 3 * inputc._height
584
+ * inputc._width * sizeof (float ));
585
+ cudaMalloc (&_buffers.data ()[_outputIndex0],
586
+ _max_batch_size * _top_k * 7 * sizeof (float ));
587
+ cudaMalloc (&_buffers.data ()[_outputIndex1],
588
+ _max_batch_size * sizeof (int ));
589
+ }
590
+ else if (_ctc)
591
+ {
592
+ throw MLLibBadParamException (
593
+ " ocr not yet implemented over tensorRT backend" );
594
+ }
595
+ else if (_timeserie)
596
+ {
597
+ throw MLLibBadParamException (
598
+ " timeseries not yet implemented over tensorRT backend" );
599
+ }
600
+ else // classification / regression
601
+ {
602
+ _buffers.resize (2 );
603
+ _floatOut.resize (_max_batch_size * this ->_nclasses );
604
+ if (inputc._bw )
605
+ cudaMalloc (&_buffers.data ()[_inputIndex],
606
+ _max_batch_size * inputc._height * inputc._width
607
+ * sizeof (float ));
608
+ else
609
+ cudaMalloc (&_buffers.data ()[_inputIndex],
610
+ _max_batch_size * 3 * inputc._height
611
+ * inputc._width * sizeof (float ));
612
+ cudaMalloc (&_buffers.data ()[_outputIndex0],
613
+ _max_batch_size * _nclasses * sizeof (float ));
614
+ }
613
615
}
614
616
}
615
617
0 commit comments