File tree 1 file changed +6
-8
lines changed
py/torch_tensorrt/runtime
1 file changed +6
-8
lines changed Original file line number Diff line number Diff line change 1
1
import logging
2
- from typing import Any
2
+ from typing import Any , Union
3
3
4
4
import torch
5
5
import torch_tensorrt
@@ -88,17 +88,15 @@ def __enter__(self) -> torch.nn.Module:
88
88
torch .ops .tensorrt .set_cudagraphs_mode (_PY_RT_CUDAGRAPHS )
89
89
90
90
logger .debug (
91
- f" { num_torch_module } torch modules are in subgraphs. Using wrapper module for cuda graphs "
91
+ "Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule "
92
92
)
93
93
return CudaGraphsTorchTensorRTModule (self .compiled_module )
94
94
else :
95
95
if num_trt_module > 0 :
96
- logger .debug (
97
- "There is no graph breaks. Using original module for cuda graphs"
98
- )
96
+ logger .debug ("No graph breaks detected, using runtime cudagraphs mode" )
99
97
else :
100
- logger .warning (
101
- "Please consider dynamo if there is graph breaks. Using original module for cuda graphs "
98
+ logger .debug (
99
+ "Please consider dynamo if there is graph breaks. Using runtime cudagraphs mode "
102
100
)
103
101
# Enable cudagraphs for TRT submodule
104
102
set_cudagraphs_mode (True )
@@ -110,6 +108,6 @@ def __exit__(self, *args: Any) -> None:
110
108
111
109
112
110
def enable_cudagraphs (
113
- compiled_module : torch .nn .Module ,
111
+ compiled_module : Union [ torch .fx . GraphModule , torch . nn .Module ] ,
114
112
) -> _CudagraphsContextManager :
115
113
return _CudagraphsContextManager (compiled_module )
You can’t perform that action at this time.
0 commit comments