
Upstream 3 features to fx_ts_compat: MS, VC, Optimization Level #1935


Merged
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -163,6 +163,9 @@ def run(
        timing_cache=None,
        profiling_verbosity=None,
        tactic_sources=None,
        max_aux_streams=None,
        version_compatible=False,
        optimization_level=None,
    ) -> TRTInterpreterResult:
        """
        Build TensorRT engine with some configs.
@@ -227,6 +230,18 @@ def run(
            if profiling_verbosity
            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
        )

        if trt.__version__ >= "8.6":
            if max_aux_streams is not None:
                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
                builder_config.max_aux_streams = max_aux_streams
            if version_compatible:
                _LOGGER.info(f"Using version compatible")
                builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
            if optimization_level is not None:
                _LOGGER.info(f"Using optimization level {optimization_level}")
                builder_config.builder_optimization_level = optimization_level

        if lower_precision == LowerPrecision.FP16:
            builder_config.set_flag(trt.BuilderFlag.FP16)

@@ -264,6 +279,7 @@ def run(
        _LOGGER.info(
            f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
        )
        _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory")

        return TRTInterpreterResult(
            engine, self._input_names, self._output_names, serialized_cache
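
For reference, the three new run() arguments map one-to-one onto TensorRT's IBuilderConfig, which gained these settings in 8.6 ("MS" = max aux streams, "VC" = version compatibility in the PR title). A minimal standalone sketch, assuming tensorrt >= 8.6 is installed; network construction is elided:

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
config = builder.create_builder_config()

config.max_aux_streams = 2                           # cap the auxiliary CUDA streams TRT may use
config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)  # build a forward-compatible engine
config.builder_optimization_level = 5                # 0 = fastest build, 5 = most optimized engine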
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
@@ -181,6 +181,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
            if self.lower_setting.verbose_profile
            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
            tactic_sources=self.lower_setting.tactic_sources,
            max_aux_streams=self.lower_setting.max_aux_streams,
            version_compatible=self.lower_setting.version_compatible,
            optimization_level=self.lower_setting.optimization_level,
        )

        # Update timing cache file if needed
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
@@ -70,6 +70,9 @@ class LowerSetting(LowerSettingBasic):
    correctness_atol: absolute tolerance for correctness check
    correctness_rtol: relative tolerance for correctness check
    use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
    max_aux_streams: maximum number of auxiliary streams to use
    version_compatible: enable the version-compatible engine feature
    optimization_level: builder optimization level
    """

    input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -96,3 +99,6 @@ class LowerSetting(LowerSettingBasic):
    correctness_atol: float = 0.1
    correctness_rtol: float = 0.1
    use_experimental_rt: bool = False
    max_aux_streams: Optional[int] = None
    version_compatible: bool = False
    optimization_level: Optional[int] = None
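
A usage sketch for the new fields (hypothetical values; the other LowerSetting fields keep their defaults): they ride along on the dataclass and, as shown in lower.py above, are forwarded verbatim to TRTInterpreter.run().

from torch_tensorrt.dynamo.fx_ts_compat.lower_setting import LowerSetting

setting = LowerSetting(
    max_aux_streams=4,        # None leaves TensorRT's default behavior untouched
    version_compatible=True,  # sets trt.BuilderFlag.VERSION_COMPATIBLE at build time
    optimization_level=3,     # becomes builder_config.builder_optimization_level
)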
90 changes: 48 additions & 42 deletions py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py
@@ -126,7 +126,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
# on pass that failed accuracy check.
def validate_inference(
    rtol=None,
    atol=None,
    device=torch.device(torch.cuda.current_device()),
    suppress_accuracy_check_failure=True,
):
    def _validate_inference(pass_: PassFunc) -> PassFunc:
        """
@@ -141,48 +144,51 @@ def pass_with_validation(
            *args,
            **kwargs,
        ) -> fx.GraphModule:
            if suppress_accuracy_check_failure:
                return pass_(module, input, *args, **kwargs)
            else:
                input_tensors = extract_example_tensors_from_input(input, device)
                res0 = module(*input_tensors)
                processed_module = pass_(module, input, *args, **kwargs)
                res1 = processed_module(*input_tensors)
                tensor_res_0 = _collect_tensors(res0)
                tensor_res_1 = _collect_tensors(res1)
                relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE

                for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
                    kwargs2 = {"equal_nan": True}
                    if rtol:
                        kwargs2["rtol"] = rtol
                    if atol:
                        kwargs2["atol"] = atol
                    kwargs2[
                        "msg"
                    ] = (
                        lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                    )
                    # If tensors are on different devices, make sure to compare
                    # their copies that are on the same device.
                    if x.get_device() != y.get_device():
                        x = x.cpu()
                        y = y.cpu()
                    try:
                        torch.testing.assert_close(x, y, **kwargs2)
                    except Exception as e:
                        if relax_accuracy_check_failure:
                            _LOGGER.error(f"{e}")
                            kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
                            kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
                            new_atol = kwargs2["atol"]
                            new_rtol = kwargs2["rtol"]
                            _LOGGER.info(
                                f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
                            )
                            torch.testing.assert_close(x, y, **kwargs2)
                            return processed_module
                        else:
                            raise e

                return processed_module

        return pass_with_validation

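
Since suppress_accuracy_check_failure defaults to True, decorated passes now skip the before/after inference comparison entirely; passing False restores the previous checking behavior. A hypothetical usage sketch (my_fusion_pass is illustrative, not part of the PR):

@validate_inference(rtol=1e-3, atol=1e-3, suppress_accuracy_check_failure=False)
def my_fusion_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
    # ... rewrite the graph here ...
    return module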