Commit 100b090 (parent 6d3064a)

feat: support Python APIs for Automatic Fallback

Signed-off-by: Bo Wang <wangbo1995ee@163.com>

5 files changed (+61, -0 lines)

py/trtorch/_compile_spec.py (+27)
@@ -122,6 +122,23 @@ def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:
 
     return info
 
+def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
+    info = trtorch._C.TorchFallback()
+    if "enabled" not in fallback_info:
+        raise KeyError("Enabled is a required parameter")
+    else:
+        assert isinstance(fallback_info["enabled"], bool)
+        info.enabled = fallback_info["enabled"]
+    if "min_block_size" in fallback_info:
+        assert isinstance(fallback_info["min_block_size"], int)
+        info.min_block_size = fallback_info["min_block_size"]
+
+    if "forced_fallback_operators" in fallback_info:
+        assert isinstance(fallback_info["forced_fallback_operators"], list)
+        info.forced_fallback_operators = fallback_info["forced_fallback_operators"]
+
+    return info
+
 
 def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     info = trtorch._C.CompileSpec()
@@ -174,6 +191,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
         assert type(compile_spec["max_batch_size"]) is int
         info.max_batch_size = compile_spec["max_batch_size"]
 
+    if "torch_fallback" in compile_spec:
+        info.torch_fallback = _parse_torch_fallback(compile_spec["torch_fallback"])
+
+
     return info
 
 
@@ -242,7 +263,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
     d.set_dla_core(parsed_spec.device.dla_core)
     d.set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
 
+    torch_fallback = torch.classes.tensorrt.TorchFallback()
+    torch_fallback.set_enabled(parsed_spec.torch_fallback.enabled)
+    torch_fallback.set_min_block_size(parsed_spec.torch_fallback.min_block_size)
+    torch_fallback.set_forced_fallback_operators(parsed_spec.torch_fallback.forced_fallback_operators)
+
     backend_spec.set_device(d)
+    backend_spec.set_torch_fallback(torch_fallback)
     backend_spec.set_op_precision(int(parsed_spec.op_precision))
     backend_spec.set_disable_tf32(parsed_spec.disable_tf32)
     backend_spec.set_refit(parsed_spec.refit)
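
With _parse_torch_fallback wired into _parse_compile_spec, the compile spec dictionary accepted by the Python API gains an optional "torch_fallback" entry. The sketch below is illustrative only: it assumes the existing trtorch.compile entry point and "input_shapes" key, and TinyModel plus the forced operator name are placeholders, not part of this commit.

    import torch
    import torch.nn as nn
    import trtorch

    class TinyModel(nn.Module):  # placeholder module for illustration
        def forward(self, x):
            return torch.relu(x)

    scripted_model = torch.jit.script(TinyModel().eval().cuda())

    compile_spec = {
        "input_shapes": [[1, 3, 224, 224]],  # existing spec key, shown for context
        "torch_fallback": {
            "enabled": True,  # required key per _parse_torch_fallback
            "min_block_size": 3,  # optional; the C++ default is 1
            "forced_fallback_operators": ["aten::max_pool2d"],  # illustrative operator list
        },
    }

    trt_module = trtorch.compile(scripted_model, compile_spec)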

py/trtorch/csrc/register_tensorrt_classes.cpp (+8)
@@ -24,11 +24,19 @@ void RegisterTRTCompileSpec() {
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
 
+  static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
+      torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "TorchFallback").def(torch::init<>());
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
+
+
   static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
       torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
           .def(torch::init<>())
           .def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
           .def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
+          .def("set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
           .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
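
This registration backs the torch.classes.tensorrt objects used on the to_backend path in _compile_spec.py above. A rough sketch of driving the registered class directly, assuming the trtorch extension is loaded and that ADD_FIELD_GET_SET_REGISTRATION exposes matching get_/set_ accessors (the set_ calls in _compile_spec.py suggest as much); the operator name is illustrative:

    import torch
    import trtorch  # importing trtorch loads the extension that registers the classes

    fb = torch.classes.tensorrt.TorchFallback()
    fb.set_enabled(True)
    fb.set_min_block_size(2)
    fb.set_forced_fallback_operators(["aten::relu"])  # illustrative operator name
    print(fb.get_min_block_size())  # assumed getter generated by the registration macro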

py/trtorch/csrc/tensorrt_classes.cpp (+3)
@@ -107,6 +107,9 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
   info.convert_info.engine_settings.device.dla_core = device.dla_core;
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
+  info.convert_info.engine_settings.torch_fallback.enabled = torch_fallback.enabled;
+  info.convert_info.engine_settings.torch_fallback.min_block_size = torch_fallback.min_block_size;
+  info.convert_info.engine_settings.torch_fallback.forced_fallback_operators = torch_fallback.forced_fallback_operators;
 
   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
   TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");

py/trtorch/csrc/tensorrt_classes.h (+17)
@@ -78,6 +78,17 @@ struct Device : torch::CustomClassHolder {
 std::string to_str(DeviceType value);
 nvinfer1::DeviceType toTRTDeviceType(DeviceType value);
 
+struct TorchFallback : torch::CustomClassHolder {
+  bool enabled;
+  int64_t min_block_size;
+  std::vector<std::string> forced_fallback_operators;
+  TorchFallback() : enabled(false), min_block_size(1) {}
+
+  ADD_FIELD_GET_SET(enabled, bool);
+  ADD_FIELD_GET_SET(min_block_size, int64_t);
+  ADD_FIELD_GET_SET(forced_fallback_operators, std::vector<std::string>);
+};
+
 enum class EngineCapability : int8_t {
   kDEFAULT,
   kSAFE_GPU,
@@ -98,6 +109,10 @@ struct CompileSpec : torch::CustomClassHolder {
     device = *d;
   }
 
+  void setTorchFallbackIntrusive(const c10::intrusive_ptr<TorchFallback> &fb) {
+    torch_fallback = *fb;
+  }
+
   ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
   ADD_FIELD_GET_SET(disable_tf32, bool);
   ADD_FIELD_GET_SET(refit, bool);
@@ -109,6 +124,7 @@
   ADD_FIELD_GET_SET(workspace_size, int64_t);
   ADD_FIELD_GET_SET(max_batch_size, int64_t);
   ADD_FIELD_GET_SET(device, Device);
+  ADD_FIELD_GET_SET(torch_fallback, TorchFallback);
 
   std::vector<InputRange> input_ranges;
   DataType op_precision = DataType::kFloat;
@@ -117,6 +133,7 @@
   bool debug = false;
   bool strict_types = false;
   Device device;
+  TorchFallback torch_fallback;
   EngineCapability capability = EngineCapability::kDEFAULT;
   int64_t num_min_timing_iters = 2;
   int64_t num_avg_timing_iters = 1;

py/trtorch/csrc/trtorch_py.cpp (+6)
@@ -124,6 +124,12 @@ PYBIND11_MODULE(_C, m) {
       .def_readwrite("dla_core", &Device::dla_core)
       .def_readwrite("allow_gpu_fallback", &Device::allow_gpu_fallback);
 
+  py::class_<TorchFallback>(m, "TorchFallback")
+      .def(py::init<>())
+      .def_readwrite("enabled", &TorchFallback::enabled)
+      .def_readwrite("min_block_size", &TorchFallback::min_block_size)
+      .def_readwrite("forced_fallback_operators", &TorchFallback::forced_fallback_operators);
+
   m.doc() =
       "TRTorch Internal C Bindings: Ahead of Time compilation for PyTorch JIT. A tool to convert PyTorch JIT to TensorRT";
   m.def(
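
The pybind11 class bound here is the object that _parse_torch_fallback fills in on the pure-Python path. A brief sketch of using it directly; the defaults follow from the TorchFallback constructor in tensorrt_classes.h, and the operator name is illustrative:

    import trtorch

    fb = trtorch._C.TorchFallback()
    assert fb.enabled is False  # default set by the C++ constructor
    assert fb.min_block_size == 1  # default set by the C++ constructor

    fb.enabled = True
    fb.forced_fallback_operators = ["aten::relu"]  # illustrative operator name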
