@@ -122,6 +122,23 @@ def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:
122
122
123
123
return info
124
124
125
+ def _parse_torch_fallback (fallback_info : Dict [str , Any ]) -> trtorch ._C .TorchFallback :
126
+ info = trtorch ._C .TorchFallback ()
127
+ if "enabled" not in fallback_info :
128
+ raise KeyError ("Enabled is required parameter" )
129
+ else :
130
+ assert isinstance (fallback_info ["enabled" ], bool )
131
+ info .enabled = fallback_info ["enabled" ]
132
+ if "min_block_size" in fallback_info :
133
+ assert isinstance (fallback_info ["min_block_size" ], int )
134
+ info .min_block_size = fallback_info ["min_block_size" ]
135
+
136
+ if "forced_fallback_operators" in fallback_info :
137
+ assert isinstance (fallback_info ["forced_fallback_operators" ], list )
138
+ info .forced_fallback_operators = fallback_info ["forced_fallback_operators" ]
139
+
140
+ return info
141
+
125
142
126
143
def _parse_compile_spec (compile_spec : Dict [str , Any ]) -> trtorch ._C .CompileSpec :
127
144
info = trtorch ._C .CompileSpec ()
@@ -174,6 +191,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
174
191
assert type (compile_spec ["max_batch_size" ]) is int
175
192
info .max_batch_size = compile_spec ["max_batch_size" ]
176
193
194
+ if "torch_fallback" in compile_spec :
195
+ info .torch_fallback = _parse_torch_fallback (compile_spec ["torch_fallback" ])
196
+
197
+
177
198
return info
178
199
179
200
@@ -242,7 +263,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
242
263
d .set_dla_core (parsed_spec .device .dla_core )
243
264
d .set_allow_gpu_fallback (parsed_spec .device .allow_gpu_fallback )
244
265
266
+ torch_fallback = torch .classes .tensorrt .TorchFallback ()
267
+ torch_fallback .set_enabled (parsed_spec .torch_fallback .enabled )
268
+ torch_fallback .set_min_block_size (parsed_spec .torch_fallback .min_block_size )
269
+ torch_fallback .set_forced_fallback_operators (parsed_spec .torch_fallback .forced_fallback_operators )
270
+
245
271
backend_spec .set_device (d )
272
+ backend_spec .set_torch_fallback (fallback )
246
273
backend_spec .set_op_precision (int (parsed_spec .op_precision ))
247
274
backend_spec .set_disable_tf32 (parsed_spec .disable_tf32 )
248
275
backend_spec .set_refit (parsed_spec .refit )
0 commit comments