diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index e1e9407070..a15202e1b3 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -48,13 +48,19 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup throw new WorkerInitializationException("Failed get TS home directory", e); } - String[] args = new String[6]; + String[] args = new String[12]; args[0] = EnvironmentUtils.getPythonRunTime(model); - args[1] = new File(workingDir, "ts/model_service_worker.py").getAbsolutePath(); - args[2] = "--sock-type"; - args[3] = connector.getSocketType(); - args[4] = connector.isUds() ? "--sock-name" : "--port"; - args[5] = connector.getSocketPath(); + args[1] = "-m"; + args[2] = "intel_extension_for_pytorch.cpu.launch"; + args[3] = "--ninstances"; + args[4] = "1"; + args[5] = "--ncore_per_instance"; + args[6] = "22"; + args[7] = new File(workingDir, "ts/model_service_worker.py").getAbsolutePath(); + args[8] = "--sock-type"; + args[9] = connector.getSocketType(); + args[10] = connector.isUds() ? "--sock-name" : "--port"; + args[11] = connector.getSocketPath(); String[] envp = EnvironmentUtils.getEnvString( diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 199b4a8894..d23bb68a75 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -9,6 +9,13 @@ import time import torch +ipex_enabled = False +try: + import intel_extension_for_pytorch as ipex + ipex_enabled = True +except: + pass + from ..utils.util import list_classes_from_module, load_label_mapping logger = logging.getLogger(__name__) @@ -73,6 +80,9 @@ def initialize(self, context): self.model = self._load_torchscript_model(model_pt_path) self.model.eval() + if ipex_enabled: + self.model = self.model.to(memory_format=torch.channels_last) + self.model = ipex.optimize(self.model, dtype=torch.float32, level='O1') logger.debug('Model file %s loaded successfully', model_pt_path) @@ -141,7 +151,10 @@ def preprocess(self, data): Returns: tensor: Returns the tensor data of the input """ - return torch.as_tensor(data, device=self.device) + t = torch.as_tensor(data, device=self.device) + if ipex_enabled: + t = t.to(memory_format=torch.channels_last) + return t def inference(self, data, *args, **kwargs): """