Skip to content

Commit 9eae269

Browse files
authoredNov 11, 2021
Merge pull request #706 from NVIDIA/update-py-readme
Update py/README.md
2 parents 6a4daef + cbcd63a commit 9eae269

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed
 

‎py/README.md

+13-15
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,24 @@ Torch-TensorRT is a compiler for PyTorch/TorchScript, targeting NVIDIA GPUs via
77
## Example Usage
88

99
``` python
10-
import torch
11-
import torchvision
1210
import torch_tensorrt
1311

14-
# Get a model
15-
model = torchvision.models.alexnet(pretrained=True).eval().cuda()
12+
...
1613

17-
# Create some example data
18-
data = torch.randn((1, 3, 224, 224)).to("cuda")
14+
trt_ts_module = torch_tensorrt.compile(torch_script_module,
15+
inputs = [example_tensor, # Provide example tensor for input shape or...
16+
torch_tensorrt.Input( # Specify input object with shape and dtype
17+
min_shape=[1, 3, 224, 224],
18+
opt_shape=[1, 3, 512, 512],
19+
max_shape=[1, 3, 1024, 1024],
20+
# For static size shape=[1, 3, 224, 224]
21+
dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
22+
],
23+
enabled_precisions = {torch.half}, # Run with FP16)
1924

20-
# Trace the module with example data
21-
traced_model = torch.jit.trace(model, [data])
25+
result = trt_ts_module(input_data) # run inference
26+
torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedded Torchscript
2227

23-
# Compile module
24-
compiled_trt_model = torch_tensorrt.compile(traced_model, {
25-
"inputs": [torch_tensorrt.Input(data.shape)],
26-
"enabled_precisions": {torch.float, torch.half}, # Run with FP16
27-
})
28-
29-
results = compiled_trt_model(data.half())
3028
```
3129

3230
## Installation

0 commit comments

Comments
 (0)