
❓ [Question] How do I load the torch tensorRT model on multiple gpus #2319


Closed
agunapal opened this issue Sep 14, 2023 · 2 comments · Fixed by #2325
Assignees: gs-olive
Labels: bug: triaged [verified] (we can replicate the bug), component: runtime, question (further information is requested)

Comments


agunapal commented Sep 14, 2023

❓ Question

In TorchServe, we have this concept of workers. In a multi-GPU node, we can assign each GPU to a worker.

I am noticing that the TensorRT model gets loaded on GPU 0 even though we specify the correct GPU ID for each worker via torch.jit.load(model_pt_path, map_location=self.device).

How do we load a Torch-TensorRT model on a device ID other than 0?
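For context, a minimal sketch (not our exact handler; the class name and model file name are illustrative) of how each worker picks its device and loads the model:

import os
import torch

class TRTHandler:
    def initialize(self, context):
        properties = context.system_properties
        # TorchServe hands each worker a distinct gpu_id via the context properties
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")))
        model_pt_path = os.path.join(properties.get("model_dir"), "trt_model_fp16.pt")
        # A plain TorchScript model lands on self.device as expected; the
        # Torch-TensorRT model's engine still ends up on GPU 0
        self.model = torch.jit.load(model_pt_path, map_location=self.device)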

What you have already tried

I have tried loading a TorchScript model; here, it loads on all 4 GPUs.

Using torch.jit.load(model_pt_path, map_location=self.device) to load the same model on each of the 4 GPUs:

2023-09-14T18:32:19,333 [INFO ] W-9000-resnet-18_1.0-stdout MODEL_LOG - cuda:1
2023-09-14T18:32:19,333 [INFO ] W-9000-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,355 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,356 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - cuda:0
2023-09-14T18:32:19,356 [INFO ] W-9003-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - cuda:3
2023-09-14T18:32:19,357 [INFO ] W-9002-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - Torch TensorRT enabled
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - cuda:2
2023-09-14T18:32:19,359 [INFO ] W-9001-resnet-18_1.0-stdout MODEL_LOG - !!!!!!!!!!!!!!!!!!!
[Screenshot attached]

A simpler repro:

import torch
import torch_tensorrt
model = torch.jit.load("trt_model_fp16.pt", map_location="cuda:1")
[Screenshot attached]

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 3.9
  • CPU Architecture:
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives: pip
  • Python version: 3.9
  • CUDA version: 11.7
  • GPU models and configuration: T4
  • Any other relevant information:

Additional context

agunapal added the question label on Sep 14, 2023
gs-olive self-assigned this on Sep 18, 2023

gs-olive commented Sep 19, 2023

Hello - I am able to reproduce this issue, and I see that the models are running on the device ordinal they were compiled on. In general, Torch-TensorRT serialized models are able to run on any GPUs of the same type/version.

On the topic of using multiple serialized instances on multiple GPUs within the same script: I have pushed a fix in #2325. With that PR, loading the model inside a torch.cuda.device context places its engine on that device. For example:

import torch

# Compile the model on GPU 0 (same GPU type as GPU 1) and save it as "trt_model_fp16.pt"

# Now load the model on GPU 1 and run inference
with torch.cuda.device(1):
    input_ = torch.randn((1, 3, 224, 224)).to("cuda:1")
    model = torch.jit.load("trt_model_fp16.pt", map_location="cuda:1")
    out = model(input_)
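For reference, a minimal sketch of the compile-and-save step assumed above; the model, input shape, and precision settings are illustrative:

import torch
import torch_tensorrt
import torchvision.models as models

# Illustrative model on GPU 0; any module supported by the TorchScript frontend works the same way
model = models.resnet18().eval().to("cuda:0")
inputs = [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.half)]
# Compiles through the TorchScript frontend and returns a ScriptModule
trt_model = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions={torch.half})
torch.jit.save(trt_model, "trt_model_fp16.pt")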

agunapal (Author) commented:

Thanks for the fix @gs-olive !
