Skip to content

Commit 44ca0db

Browse files
authored
Update base_handler.py
1 parent 59e7f42 commit 44ca0db

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

ts/torch_handler/base_handler.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
import time
1010
import torch
1111

12+
ipex_enabled = False
13+
if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
14+
try:
15+
import intel_extension_for_pytorch as ipex
16+
ipex_enabled = True
17+
except:
18+
pass
19+
1220
from ..utils.util import list_classes_from_module, load_label_mapping
1321

1422
logger = logging.getLogger(__name__)
@@ -73,6 +81,9 @@ def initialize(self, context):
7381
self.model = self._load_torchscript_model(model_pt_path)
7482

7583
self.model.eval()
84+
if ipex_enabled:
85+
self.model = self.model.to(memory_format=torch.channels_last)
86+
self.model = ipex.optimize(self.model, dtype=torch.float32, level='O1')
7687

7788
logger.debug('Model file %s loaded successfully', model_pt_path)
7889

@@ -141,7 +152,10 @@ def preprocess(self, data):
141152
Returns:
142153
tensor: Returns the tensor data of the input
143154
"""
144-
return torch.as_tensor(data, device=self.device)
155+
t = torch.as_tensor(data, device=self.device)
156+
if ipex_enabled and t.dim() == 4:
157+
t = t.to(memory_format=torch.channels_last)
158+
return t
145159

146160
def inference(self, data, *args, **kwargs):
147161
"""

0 commit comments

Comments
 (0)