File tree 1 file changed +15
-1
lines changed
1 file changed +15
-1
lines changed Original file line number Diff line number Diff line change 9
9
import time
10
10
import torch
11
11
12
+ ipex_enabled = False
13
+ if os .environ .get ("TS_IPEX_ENABLE" , "false" ) == "true" :
14
+ try :
15
+ import intel_extension_for_pytorch as ipex
16
+ ipex_enabled = True
17
+ except :
18
+ pass
19
+
12
20
from ..utils .util import list_classes_from_module , load_label_mapping
13
21
14
22
logger = logging .getLogger (__name__ )
@@ -73,6 +81,9 @@ def initialize(self, context):
73
81
self .model = self ._load_torchscript_model (model_pt_path )
74
82
75
83
self .model .eval ()
84
+ if ipex_enabled :
85
+ self .model = self .model .to (memory_format = torch .channels_last )
86
+ self .model = ipex .optimize (self .model , dtype = torch .float32 , level = 'O1' )
76
87
77
88
logger .debug ('Model file %s loaded successfully' , model_pt_path )
78
89
@@ -141,7 +152,10 @@ def preprocess(self, data):
141
152
Returns:
142
153
tensor: Returns the tensor data of the input
143
154
"""
144
- return torch .as_tensor (data , device = self .device )
155
+ t = torch .as_tensor (data , device = self .device )
156
+ if ipex_enabled and t .dim () == 4 :
157
+ t = t .to (memory_format = torch .channels_last )
158
+ return t
145
159
146
160
def inference (self , data , * args , ** kwargs ):
147
161
"""
You can’t perform that action at this time.
0 commit comments