diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 2f8e545f3390..cc93ec147c52 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -6,6 +6,7 @@
 import torch
 import time
 import os
+import deepspeed.utils.nvtx
 from deepspeed import comm as dist
 from deepspeed.utils.logging import log_dist
 
@@ -646,6 +647,9 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg
 
         if self._is_compiled:
             return
+
+        # Avoid graph breaks: switch off deepspeed.utils.nvtx.enable_nvtx before compiling.
+        deepspeed.utils.nvtx.enable_nvtx = False
         self.module.compile(backend=backend, **compile_kwargs)
         self._is_compiled = True
 