@@ -7,26 +7,24 @@ Torch-TensorRT is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via
7
7
## Example Usage
8
8
9
9
``` python
10
- import torch
11
- import torchvision
12
10
import torch_tensorrt
13
11
14
- # Get a model
15
- model = torchvision.models.alexnet(pretrained=True).eval().cuda()
12
+ ...
16
13
17
- # Create some example data
18
- data = torch.randn((1, 3, 224, 224)).to("cuda")
14
+ trt_ts_module = torch_tensorrt.compile(torch_script_module,
15
+ inputs = [example_tensor, # Provide example tensor for input shape or...
16
+ torch_tensorrt.Input( # Specify input object with shape and dtype
17
+ min_shape = [1, 3, 224, 224],
18
+ opt_shape = [1, 3, 512, 512],
19
+ max_shape = [1, 3, 1024, 1024],
20
+ # For static size shape=[1, 3, 224, 224]
21
+ dtype = torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
22
+ ],
23
+ enabled_precisions = {torch.half}) # Run with FP16
19
24
20
- # Trace the module with example data
21
- traced_model = torch.jit.trace(model, [data])
25
+ result = trt_ts_module(input_data) # run inference
26
+ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript
22
27
23
- # Compile module
24
- compiled_trt_model = torch_tensorrt.compile(traced_model, {
25
- "inputs": [torch_tensorrt.Input(data.shape)],
26
- "enabled_precisions": {torch.float, torch.half}, # Run with FP16
27
- })
28
-
29
- results = compiled_trt_model(data.half())
30
28
```
31
29
32
30
## Installation
0 commit comments