7
7
8
8
#include < string>
9
9
10
- #include " sherpa-onnx/csrc/parse-options.h"
11
- #include " sherpa-onnx/csrc/macros.h"
12
10
#include " onnxruntime_cxx_api.h" // NOLINT
11
+ #include " sherpa-onnx/csrc/macros.h"
12
+ #include " sherpa-onnx/csrc/parse-options.h"
13
13
14
14
namespace sherpa_onnx {
15
15
@@ -40,25 +40,23 @@ struct TensorrtConfig {
40
40
41
41
TensorrtConfig () = default ;
42
42
TensorrtConfig (int64_t trt_max_workspace_size,
43
- int32_t trt_max_partition_iterations,
44
- int32_t trt_min_subgraph_size,
45
- bool trt_fp16_enable,
46
- bool trt_detailed_build_log,
47
- bool trt_engine_cache_enable,
48
- bool trt_timing_cache_enable,
49
- const std::string &trt_engine_cache_path,
50
- const std::string &trt_timing_cache_path,
51
- bool trt_dump_subgraphs)
43
+ int32_t trt_max_partition_iterations,
44
+ int32_t trt_min_subgraph_size, bool trt_fp16_enable,
45
+ bool trt_detailed_build_log, bool trt_engine_cache_enable,
46
+ bool trt_timing_cache_enable,
47
+ const std::string &trt_engine_cache_path,
48
+ const std::string &trt_timing_cache_path,
49
+ bool trt_dump_subgraphs)
52
50
: trt_max_workspace_size(trt_max_workspace_size),
53
- trt_max_partition_iterations (trt_max_partition_iterations),
54
- trt_min_subgraph_size(trt_min_subgraph_size),
55
- trt_fp16_enable(trt_fp16_enable),
56
- trt_detailed_build_log(trt_detailed_build_log),
57
- trt_engine_cache_enable(trt_engine_cache_enable),
58
- trt_timing_cache_enable(trt_timing_cache_enable),
59
- trt_engine_cache_path(trt_engine_cache_path),
60
- trt_timing_cache_path(trt_timing_cache_path),
61
- trt_dump_subgraphs(trt_dump_subgraphs) {}
51
+ trt_max_partition_iterations (trt_max_partition_iterations),
52
+ trt_min_subgraph_size(trt_min_subgraph_size),
53
+ trt_fp16_enable(trt_fp16_enable),
54
+ trt_detailed_build_log(trt_detailed_build_log),
55
+ trt_engine_cache_enable(trt_engine_cache_enable),
56
+ trt_timing_cache_enable(trt_timing_cache_enable),
57
+ trt_engine_cache_path(trt_engine_cache_path),
58
+ trt_timing_cache_path(trt_timing_cache_path),
59
+ trt_dump_subgraphs(trt_dump_subgraphs) {}
62
60
63
61
void Register (ParseOptions *po);
64
62
bool Validate () const ;
@@ -74,15 +72,15 @@ struct ProviderConfig {
74
72
// device only used for cuda and trt
75
73
76
74
ProviderConfig () = default ;
77
- ProviderConfig (const std::string &provider,
78
- int32_t device)
75
+ ProviderConfig (const std::string &provider, int32_t device)
79
76
: provider(provider), device(device) {}
80
77
ProviderConfig (const TensorrtConfig &trt_config,
81
- const CudaConfig &cuda_config,
82
- const std::string &provider,
83
- int32_t device)
84
- : trt_config(trt_config), cuda_config(cuda_config),
85
- provider (provider), device(device) {}
78
+ const CudaConfig &cuda_config, const std::string &provider,
79
+ int32_t device)
80
+ : trt_config(trt_config),
81
+ cuda_config (cuda_config),
82
+ provider(provider),
83
+ device(device) {}
86
84
87
85
void Register (ParseOptions *po);
88
86
bool Validate () const ;
0 commit comments