Skip to content

Commit f560fec

Browse files
committed
chore: Proper logging message and rebase
1 parent 11886fe commit f560fec

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

Diff for: py/torch_tensorrt/runtime/_cudagraphs.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any
2+
from typing import Any, Union
33

44
import torch
55
import torch_tensorrt
@@ -88,17 +88,15 @@ def __enter__(self) -> torch.nn.Module:
8888
torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
8989

9090
logger.debug(
91-
f"{num_torch_module} torch modules are in subgraphs. Using wrapper module for cuda graphs"
91+
"Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
9292
)
9393
return CudaGraphsTorchTensorRTModule(self.compiled_module)
9494
else:
9595
if num_trt_module > 0:
96-
logger.debug(
97-
"There is no graph breaks. Using original module for cuda graphs"
98-
)
96+
logger.debug("No graph breaks detected, using runtime cudagraphs mode")
9997
else:
100-
logger.warning(
101-
"Please consider dynamo if there is graph breaks. Using original module for cuda graphs"
98+
logger.debug(
99+
"Please consider dynamo if there are graph breaks. Using runtime cudagraphs mode"
102100
)
103101
# Enable cudagraphs for TRT submodule
104102
set_cudagraphs_mode(True)
@@ -110,6 +108,6 @@ def __exit__(self, *args: Any) -> None:
110108

111109

112110
def enable_cudagraphs(
113-
compiled_module: torch.nn.Module,
111+
compiled_module: Union[torch.fx.GraphModule, torch.nn.Module],
114112
) -> _CudagraphsContextManager:
115113
return _CudagraphsContextManager(compiled_module)

0 commit comments

Comments
 (0)