Skip to content

Commit 9760fe3

Browse files
committed
feat(jetson): Support for Jetpack 4.6
This commit adds support for Jetpack 4.6. Users should now add the --platforms flag to bazel compilation to target the jetpack version e.g. `bazel build //:libtrtorch --platforms //toolchains:jetpack_4.6` By default setup.py now expects Jetpack 4.6. To override add the `--jetpack-version 4.5` flag Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 744b417 commit 9760fe3

File tree

13 files changed

+101
-34
lines changed

13 files changed

+101
-34
lines changed

core/conversion/conversionctx/ConversionCtx.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct Device {
2424
};
2525

2626
struct BuilderSettings {
27-
std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
27+
std::set<nvinfer1::DataType> enabled_precisions = {};
2828
bool sparse_weights = false;
2929
bool disable_tf32 = false;
3030
bool refit = false;

docsrc/tutorials/installation.rst

+16-24
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ Install or compile a build of PyTorch/LibTorch for aarch64
237237

238238
NVIDIA hosts builds the latest release branch for Jetson here:
239239

240-
https://forums.developer.nvidia.com/t/pytorch-for-jetson-nano-version-1-5-0-now-available/72048
240+
https://forums.developer.nvidia.com/t/pytorch-for-jetson-version-1-9-0-now-available/72048
241241

242242

243243
Enviorment Setup
@@ -285,29 +285,10 @@ To build natively on aarch64-linux-gnu platform, configure the ``WORKSPACE`` wit
285285
# strip_prefix = "TensorRT-7.1.3.4"
286286
#)
287287
288+
NOTE: You may also need to configure the CUDA version to 10.2 by setting the path for the cuda new_local_repository
288289
289-
2. Disable Python API testing dependencies:
290-
291-
.. code-block:: shell
292-
293-
#pip3_import(
294-
# name = "trtorch_py_deps",
295-
# requirements = "//py:requirements.txt"
296-
#)
297-
298-
#load("@trtorch_py_deps//:requirements.bzl", "pip_install")
299-
#pip_install()
300-
301-
#pip3_import(
302-
# name = "py_test_deps",
303-
# requirements = "//tests/py:requirements.txt"
304-
#)
305-
306-
#load("@py_test_deps//:requirements.bzl", "pip_install")
307-
#pip_install()
308290
309-
310-
3. Configure the correct paths to directory roots containing local dependencies in the ``new_local_repository`` rules:
291+
2. Configure the correct paths to directory roots containing local dependencies in the ``new_local_repository`` rules:
311292

312293
NOTE: If you installed PyTorch using a pip package, the correct path is the path to the root of the python torch package.
313294
In the case that you installed with ``sudo pip install`` this will be ``/usr/local/lib/python3.6/dist-packages/torch``.
@@ -346,19 +327,30 @@ use that library, set the paths to the same path but when you compile make sure
346327
Compile C++ Library and Compiler CLI
347328
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
348329

330+
NOTE: Due to shifting dependency locations between Jetpack 4.5 and 4.6 there is a now a flag to inform bazel of the Jetpack version
331+
332+
.. code-block:: shell
333+
334+
--platforms //toolchains:jetpack_4.x
335+
336+
349337
Compile TRTorch library using bazel command:
350338

351339
.. code-block:: shell
352340
353-
bazel build //:libtrtorch
341+
bazel build //:libtrtorch --platforms //toolchains:jetpack_4.6
354342
355343
Compile Python API
356344
^^^^^^^^^^^^^^^^^^^^
357345

346+
NOTE: Due to shifting dependencies locations between Jetpack 4.5 and Jetpack 4.6 there is now a flag for ``setup.py`` which sets the jetpack version (default: 4.6)
347+
358348
Compile the Python API using the following command from the ``//py`` directory:
359349

360350
.. code-block:: shell
361351
362352
python3 setup.py install --use-cxx11-abi
363353
364-
If you have a build of PyTorch that uses Pre-CXX11 ABI drop the ``--use-cxx11-abi`` flag
354+
If you have a build of PyTorch that uses Pre-CXX11 ABI drop the ``--use-cxx11-abi`` flag
355+
356+
If you are building for Jetpack 4.5 add the ``--jetpack-version 4.5`` flag

py/setup.py

+25
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,35 @@
1313
from shutil import copyfile, rmtree
1414

1515
import subprocess
16+
import platform
17+
import warnings
1618

1719
dir_path = os.path.dirname(os.path.realpath(__file__))
1820

1921
__version__ = '0.4.0a0'
2022

2123
CXX11_ABI = False
2224

25+
JETPACK_VERSION = None
26+
2327
if "--use-cxx11-abi" in sys.argv:
2428
sys.argv.remove("--use-cxx11-abi")
2529
CXX11_ABI = True
2630

31+
if platform.uname().processor == "aarch64":
32+
if "--jetpack-version" in sys.argv:
33+
version_idx = sys.argv.index("--jetpack-version") + 1
34+
version = sys.argv[version_idx]
35+
sys.argv.remove(version)
36+
sys.argv.remove("--jetpack-version")
37+
if version == "4.5":
38+
JETPACK_VERSION = "4.5"
39+
elif version == "4.6":
40+
JETPACK_VERSION = "4.6"
41+
if not JETPACK_VERSION:
42+
warnings.warn("Assuming jetpack version to be 4.6, if not use the --jetpack-version option")
43+
JETPACK_VERSION = "4.6"
44+
2745

2846
def which(program):
2947
import os
@@ -66,6 +84,13 @@ def build_libtrtorch_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=Fa
6684
else:
6785
print("using CXX11 ABI build")
6886

87+
if JETPACK_VERSION == "4.5":
88+
cmd.append("--platforms=//toolchains:jetpack_4.5")
89+
print("Jetpack version: 4.5")
90+
elif JETPACK_VERSION == "4.6":
91+
cmd.append("--platforms=//toolchains:jetpack_4.6")
92+
print("Jetpack version: 4.6")
93+
6994
print("building libtrtorch")
7095
status_code = subprocess.run(cmd).returncode
7196

py/trtorch/csrc/tensorrt_classes.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ struct CompileSpec : torch::CustomClassHolder {
155155

156156
std::vector<Input> inputs;
157157
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
158-
std::set<DataType> enabled_precisions = {DataType::kFloat};
158+
std::set<DataType> enabled_precisions = {};
159159
bool sparse_weights = false;
160160
bool disable_tf32 = false;
161161
bool refit = false;

tests/core/conversion/evaluators/test_aten_evaluators.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "gtest/gtest.h"
44
#include "tests/util/util.h"
55
#include "torch/csrc/jit/ir/irparser.h"
6+
#include "torch/torch.h"
67

78
TEST(Evaluators, DivIntEvaluatesCorrectly) {
89
const auto graph = R"IR(

tests/cpp/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ cc_test(
7777
deps = [
7878
":cpp_api_test",
7979
],
80+
timeout="long"
8081
)
8182

8283
cc_test(

tests/modules/hub.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torchvision.models as models
55
import timm
66

7+
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
8+
79
models = {
810
"alexnet": {
911
"model": models.alexnet(pretrained=True),

tests/modules/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
timm==v0.4.12

tests/py/test_api_dla.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_compile_traced(self):
3737
"dla_core": 0,
3838
"allow_gpu_fallback": True
3939
},
40-
"enabled_precision": {torch.float, torch.half}
40+
"enabled_precisions": {torch.half}
4141
}
4242

4343
trt_mod = trtorch.compile(self.traced_model, compile_spec)
@@ -53,7 +53,7 @@ def test_compile_script(self):
5353
"dla_core": 0,
5454
"allow_gpu_fallback": True
5555
},
56-
"enabled_precision": {torch.float, torch.half}
56+
"enabled_precisions": {torch.half}
5757
}
5858

5959
trt_mod = trtorch.compile(self.scripted_model, compile_spec)

third_party/cublas/BUILD

+18-5
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,21 @@ package(default_visibility = ["//visibility:public"])
33
# NOTE: This BUILD file is only really targeted at aarch64, the rest of the configuration is just to satisfy bazel, x86 uses the cublas source from the CUDA build file since it will be versioned with CUDA.
44

55
config_setting(
6-
name = "aarch64_linux",
6+
name = "jetpack_4.5",
77
constraint_values = [
88
"@platforms//cpu:aarch64",
99
"@platforms//os:linux",
10-
],
10+
"@//toolchains/jetpack:4.5"
11+
]
12+
)
13+
14+
config_setting(
15+
name = "jetpack_4.6",
16+
constraint_values = [
17+
"@platforms//cpu:aarch64",
18+
"@platforms//os:linux",
19+
"@//toolchains/jetpack:4.6"
20+
]
1121
)
1222

1323
config_setting(
@@ -20,7 +30,8 @@ config_setting(
2030
cc_library(
2131
name = "cublas_headers",
2232
hdrs = select({
23-
":aarch64_linux": ["include/cublas.h"] + glob(["usr/include/cublas+.h"]),
33+
":jetpack_4.5": ["include/cublas.h"] + glob(["usr/include/cublas+.h"]),
34+
":jetpack_4.6": ["local/cuda/include/cublas.h"] + glob(["usr/cuda/include/cublas+.h"]),
2435
"//conditions:default": ["local/cuda/include/cublas.h"] + glob(["usr/cuda/include/cublas+.h"]),
2536
}),
2637
includes = ["include/"],
@@ -30,7 +41,8 @@ cc_library(
3041
cc_import(
3142
name = "cublas_lib",
3243
shared_library = select({
33-
":aarch64_linux": "lib/aarch64-linux-gnu/libcublas.so",
44+
":jetpack_4.5": "lib/aarch64-linux-gnu/libcublas.so",
45+
":jetpack_4.6": "local/cuda/targets/aarch64-linux/lib/libcublas.so",
3446
":windows": "lib/x64/cublas.lib",
3547
"//conditions:default": "local/cuda/targets/x86_64-linux/lib/libcublas.so",
3648
}),
@@ -40,7 +52,8 @@ cc_import(
4052
cc_import(
4153
name = "cublas_lt_lib",
4254
shared_library = select({
43-
":aarch64_linux": "lib/aarch64-linux-gnu/libcublasLt.so",
55+
":jetpack_4.5": "lib/aarch64-linux-gnu/libcublasLt.so",
56+
":jetpack_4.6": "local/cuda/targets/aarch64-linux/lib/libcublasLt.so",
4457
"//conditions:default": "local/cuda/targets/x86_64-linux/lib/libcublasLt.so",
4558
}),
4659
visibility = ["//visibility:private"],

third_party/tensorrt/local/BUILD

+4-1
Original file line numberDiff line numberDiff line change
@@ -319,5 +319,8 @@ cc_library(
319319
],
320320
linkopts = [
321321
"-lpthread",
322-
]
322+
] + select({
323+
":aarch64_linux": ["-Wl,--no-as-needed -ldl -lrt -Wl,--as-needed"],
324+
"//conditions:default": []
325+
})
323326
)

toolchains/BUILD

+18
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,21 @@ platform(
77
"@platforms//cpu:aarch64",
88
],
99
)
10+
11+
platform(
12+
name = "jetpack_4.5",
13+
constraint_values = [
14+
"@platforms//os:linux",
15+
"@platforms//cpu:aarch64",
16+
"@//toolchains/jetpack:4.5"
17+
]
18+
)
19+
20+
platform(
21+
name = "jetpack_4.6",
22+
constraint_values = [
23+
"@platforms//os:linux",
24+
"@platforms//cpu:aarch64",
25+
"@//toolchains/jetpack:4.6"
26+
]
27+
)

toolchains/jetpack/BUILD

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
constraint_setting(name = "jetpack")
4+
constraint_value(
5+
name = "4.5",
6+
constraint_setting = ":jetpack"
7+
)
8+
constraint_value(
9+
name = "4.6",
10+
constraint_setting = ":jetpack"
11+
)

0 commit comments

Comments
 (0)