
Commit 4a815e0

Merge branch 'main' into export_merge
2 parents 87f4627 + 4b608f0

167 files changed: +1444 −345 lines


.github/workflows/assigner.yml (+3)

@@ -16,6 +16,9 @@ on:
 
 jobs:
   assign:
+    permissions:
+      contents: read
+      pull-requests: write
     runs-on: ubuntu-latest
     steps:
       - name: Checkout

.github/workflows/build-test.yml (+3 −3)

@@ -26,6 +26,9 @@ jobs:
 
   build:
     needs: generate-matrix
+    permissions:
+      id-token: write
+      contents: read
     strategy:
       fail-fast: false
       matrix:
@@ -50,9 +53,6 @@ jobs:
       package-name: ${{ matrix.package-name }}
       smoke-test-script: ${{ matrix.smoke-test-script }}
       trigger-event: ${{ github.event_name }}
-    secrets:
-      AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
-      AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
 
   tests-py-torchscript-fe:
     name: Test torchscript frontend [Python]

.github/workflows/label.yml (+4 −3)

@@ -10,11 +10,12 @@ on: [pull_request_target]
 
 jobs:
   label:
-
+    permissions:
+      contents: read
+      pull-requests: write
     runs-on: ubuntu-latest
-
     steps:
-    - uses: actions/labeler@v2
+    - uses: actions/labeler@v4
      with:
        repo-token: "${{ secrets.GITHUB_TOKEN }}"
        configuration-path: .github/pr-labels.yml

.github/workflows/linux-test.yml (+4 −1)

@@ -67,6 +67,7 @@ jobs:
       CU_VERSION: ${{ matrix.desired_cuda }}
       SCRIPT: ${{ inputs.script }}
       RUNNER_TEST_RESULTS_DIR: /tmp/test_results
+      ARCH: ${{ inputs.architecture }}
     name: ${{ inputs.job-name }}-${{ matrix.desired_cuda }}
     runs-on: ${{ matrix.validation_runner }}
     container:
@@ -100,6 +101,8 @@ jobs:
           ref: ${{ inputs.ref }}
           setup-miniconda: ${{ inputs.setup-miniconda }}
           python-version: ${{ env.PYTHON_VERSION }}
+          cuda-version: ${{ env.CU_VERSION }}
+          arch: ${{ env.ARCH }}
       - name: Run Pre-Script with Caching
         if: ${{ inputs.pre-script != '' }}
         uses: ./test-infra/.github/actions/run-script-with-cache
@@ -191,4 +194,4 @@ jobs:
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
-  cancel-in-progress: true
+  cancel-in-progress: true

README.md (+1 −1)

@@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
 
 - Bazel 5.2.0
-- Libtorch 2.2.0.dev (latest nightly) (built with CUDA 12.1)
+- Libtorch 2.3.0.dev (latest nightly) (built with CUDA 12.1)
 - CUDA 12.1
 - cuDNN 8.9.5
 - TensorRT 8.6.1

core/conversion/converters/BUILD (+1, file mode 100755 → 100644)

@@ -66,6 +66,7 @@ cc_library(
         "impl/einsum.cpp",
         "impl/element_wise.cpp",
         "impl/expand.cpp",
+        "impl/internal_ops.cpp",
         "impl/interpolate.cpp",
         "impl/layer_norm.cpp",
         "impl/linear.cpp",
core/conversion/converters/impl/internal_ops.cpp (new file, +46)

@@ -0,0 +1,46 @@
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+#include "torch/torch.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"trt::attn_bias_from_attn_mask(Tensor attn_mask) -> Tensor",
+     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       // Converter for internal op used in unpack_scaled_dot_product_attention
+       // We don't have visibility to check types during lowering and can't introduce conditionals so do type specific
+       // specialization here
+       auto in = args[0].ITensorOrFreeze(ctx);
+       auto out = in;
+       if (in->getType() == nvinfer1::DataType::kBOOL) {
+         auto not_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
+         TORCHTRT_CHECK(not_layer, "Unable to create not layer for attn_bias_from_attn_mask");
+         not_layer->setName((util::node_info(n) + "_not").c_str());
+         auto neg_inf = torch::tensor(-std::numeric_limits<float>::infinity());
+         auto neg_inf_itensor = tensor_to_const(ctx, neg_inf);
+         auto prod_layer = add_elementwise(
+             ctx,
+             nvinfer1::ElementWiseOperation::kPROD,
+             not_layer->getOutput(0),
+             neg_inf_itensor,
+             util::node_info(n) + "_mul");
+         auto add_layer = add_elementwise(
+             ctx, nvinfer1::ElementWiseOperation::kSUM, prod_layer->getOutput(0), in, util::node_info(n) + "_add");
+         out = add_layer->getOutput(0);
+       }
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out);
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+       LOG_DEBUG("Output tensor type: " << out_tensor->getType());
+       return true;
+     }});
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace torch_tensorrt
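For context, the boolean-mask branch of this converter builds an additive attention bias: positions the mask allows (True) contribute 0 and masked-out positions (False) contribute -inf; float masks pass through unchanged. A minimal eager-mode sketch of that intended behavior, assuming standard PyTorch semantics (the Python function name below is illustrative only, not an API from this commit):

import torch

def attn_bias_from_bool_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    # Boolean masks become additive biases: True -> 0.0, False -> -inf.
    if attn_mask.dtype == torch.bool:
        bias = torch.zeros(attn_mask.shape, dtype=torch.float32)
        return bias.masked_fill(~attn_mask, float("-inf"))
    # Non-boolean masks are already additive biases and pass through unchanged.
    return attn_mask

mask = torch.tensor([True, False, True])
print(attn_bias_from_bool_mask(mask))  # tensor([0., -inf, 0.])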

core/conversion/converters/impl/unary.cpp (+16 −1)

@@ -79,6 +79,22 @@ auto logical_not_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
       return true;
     }});
 
+auto sqrt_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::sqrt(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       auto in = args[0].ITensorOrFreeze(ctx);
+       if (in->getType() == nvinfer1::DataType::kINT32) {
+         // unary sqrt layer only supports float inputs
+         in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT, util::node_info(n).c_str());
+       }
+       auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kSQRT);
+       TORCHTRT_CHECK(unary_layer, "Unable to create sqrt layer from node: " << *n);
+       unary_layer->setName(util::node_info(n).c_str());
+       unary_layer->setOutputType(0, in->getType());
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+       return true;
+     }});
+
 auto isfinite_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
     {"aten::isfinite(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensorOrFreeze(ctx);
@@ -126,7 +142,6 @@ convert(atan, kATAN);
 convert(floor, kFLOOR);
 convert(log, kLOG);
 convert(ceil, kCEIL);
-convert(sqrt, kSQRT);
 convert(exp, kEXP);
 convert(neg, kNEG);
 convert(erf, kERF);
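The new aten::sqrt converter casts INT32 inputs to FLOAT before applying TensorRT's unary kSQRT layer, which only accepts floating-point tensors. Roughly the equivalent eager-mode behavior (a sketch, not the converter itself):

import torch

x = torch.tensor([1, 4, 9], dtype=torch.int32)
# Integer inputs are cast to float32 before the square root is taken,
# mirroring the INT32 -> FLOAT cast in the converter above.
y = torch.sqrt(x.to(torch.float32))
print(y)  # tensor([1., 2., 3.])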

core/lowering/lowering.cpp (+1)

@@ -146,6 +146,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   if (lower_info.converting_to_trt_engine) {
     passes::RemoveCollectionCast(g);
   }
+  passes::UnpackScaledDotProductAttention(g);
   passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
   passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
   passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());

core/lowering/passes/BUILD (+1)

@@ -38,6 +38,7 @@ cc_library(
         "unpack_hardswish.cpp",
         "unpack_log_softmax.cpp",
         "unpack_rsqrt.cpp",
+        "unpack_scaled_dot_product_attention.cpp",
         "unpack_std.cpp",
         "unpack_var.cpp",
         "view_to_reshape.cpp",

core/lowering/passes/passes.h (+1)

@@ -49,6 +49,7 @@ void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph);
core/lowering/passes/unpack_scaled_dot_product_attention.cpp (new file, +94)

@@ -0,0 +1,94 @@
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+// https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string sdpa_pattern = R"IR(
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
+      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal)
+      return (%out))IR";
+
+  std::string unpacked_sdpa_pattern = R"IR(
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
+      %none : NoneType = prim::Constant()
+      %1 : int = prim::Constant[value=-1]()
+      %2 : int = prim::Constant[value=-2]()
+      %3 : int = aten::size(%query, %1)
+      %q_size : Long() = prim::NumToTensor(%3)
+      %sqrt : Tensor = aten::sqrt(%q_size)
+      %scale_factor : Tensor = aten::reciprocal(%sqrt)
+      %key_transpose : Tensor = aten::transpose(%key, %2, %1)
+      %matmul : Tensor = aten::matmul(%query, %key_transpose)
+      %attn_weight : Tensor = aten::mul(%matmul, %scale_factor)
+      %softmax : Tensor = aten::softmax(%attn_weight, %1, %none)
+      %out : Tensor = aten::matmul(%softmax, %value)
+      return(%out))IR";
+
+  std::string unpacked_sdpa_attn_biased_pattern = R"IR(
+    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
+      %none : NoneType = prim::Constant()
+      %0 : int = prim::Constant[value=1]()
+      %1 : int = prim::Constant[value=-1]()
+      %2 : int = prim::Constant[value=-2]()
+      %3 : int = aten::size(%query, %1)
+      %q_size : Long() = prim::NumToTensor(%3)
+      %sqrt : Tensor = aten::sqrt(%q_size)
+      %scale_factor : Tensor = aten::reciprocal(%sqrt)
+      %key_transpose : Tensor = aten::transpose(%key, %2, %1)
+      %matmul : Tensor = aten::matmul(%query, %key_transpose)
+      %attn_weight : Tensor = aten::mul(%matmul, %scale_factor)
+      %attn_bias : Tensor = trt::attn_bias_from_attn_mask(%attn_mask)
+      %attn_weight_with_bias : Tensor = aten::add(%attn_weight, %attn_bias, %0)
+      %softmax : Tensor = aten::softmax(%attn_weight_with_bias, %1, %none)
+      %out : Tensor = aten::matmul(%softmax, %value)
+      return(%out))IR";
+
+  // rewrite with None attn_mask
+  torch::jit::SubgraphRewriter sdpa_rewriter;
+  sdpa_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_pattern);
+  sdpa_rewriter.runOnGraph(
+      graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) {
+        auto is_causal_node = match.anchor->inputs().at(5)->node();
+        if (is_causal_node->kind() != at::prim::Constant) {
+          LOG_WARNING("Could not unpack scaled_dot_product_attention with non constant is_causal: " << *is_causal_node);
+          return false;
+        }
+        if (is_causal_node->i(at::attr::value) == 1) {
+          LOG_WARNING("Could not unpack scaled_dot_product_attention with is_causal = True: " << *is_causal_node);
+          return false;
+        }
+        auto attn_mask_node = match.anchor->inputs().at(3)->node();
+        if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) {
+          return false;
+        }
+        return true;
+      });
+
+  // rewrite with float/bool attn_mask this uses a custom op to implement the divergent behavior between bool and float
+  // masks without a conditional
+  torch::jit::SubgraphRewriter sdpa_attn_mask_rewriter;
+  sdpa_attn_mask_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_attn_biased_pattern);
+  sdpa_attn_mask_rewriter.runOnGraph(
+      graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) {
+        auto is_causal_node = match.anchor->inputs().at(5)->node();
+        if (is_causal_node->kind() != at::prim::Constant || is_causal_node->i(at::attr::value) == 1) {
+          // messages already written in first pass, do not write again
+          return false;
+        }
+        return true;
+      });
+  LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
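As a sanity check on the rewrite patterns above: with dropout_p = 0 and is_causal = False, scaled_dot_product_attention reduces to softmax(Q·Kᵀ/√E + bias)·V, which is exactly the structure of the unpacked graphs. A rough eager-mode equivalent, assuming standard PyTorch semantics (a sketch, not part of this commit):

import math
import torch
import torch.nn.functional as F

def unpacked_sdpa(query, key, value, attn_mask=None):
    # Mirrors the unpacked IR: scale by 1/sqrt(last query dim), optionally add
    # an additive attention bias, softmax over the last dim, then matmul with value.
    scale_factor = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale_factor
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    return torch.matmul(torch.softmax(attn_weight, dim=-1), value)

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
assert torch.allclose(unpacked_sdpa(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-5)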

core/lowering/register_trt_placeholder_ops.cpp (+12)

@@ -1,3 +1,4 @@
+#include <limits>
 #include "torch/csrc/jit/runtime/custom_operator.h"
 
 namespace torch {
@@ -14,6 +15,17 @@ RegisterOperators trt_placeholder_ops_reg({
         "trt::const(Tensor val) -> Tensor",
         [](Stack& stack) { /*noop*/ },
         aliasAnalysisFromSchema()),
+    Operator(
+        "trt::attn_bias_from_attn_mask(Tensor attn_mask) -> Tensor",
+        [](Stack& stack) {
+          auto attn_mask = pop(stack).to<at::Tensor>();
+          if (attn_mask.scalar_type() == at::kBool) {
+            attn_mask = attn_mask;
+            attn_mask.masked_fill_(attn_mask.logical_not(), -std::numeric_limits<float>::infinity());
+          }
+          return attn_mask;
+        },
+        c10::AliasAnalysisKind::CONSERVATIVE),
 });
 
 } // namespace jit

cpp/include/torch_tensorrt/macros.h (+1 −1)

@@ -24,7 +24,7 @@
 #define STR(x) XSTR(x)
 
 #define TORCH_TENSORRT_MAJOR_VERSION 2
-#define TORCH_TENSORRT_MINOR_VERSION 2
+#define TORCH_TENSORRT_MINOR_VERSION 3
 #define TORCH_TENSORRT_PATCH_VERSION 0
 #define TORCH_TENSORRT_VERSION \
   STR(TORCH_TENSORRT_MAJOR_VERSION) \

dev_dep_versions.yml (+1 −1)

@@ -1,4 +1,4 @@
-__version__: "2.2.0.dev0"
+__version__: "2.3.0.dev0"
 __cuda_version__: "12.1"
 __cudnn_version__: "8.9"
 __tensorrt_version__: "8.6"

docs/_cpp_api/classtorch__tensorrt_1_1DataType.html (+2 −2)

@@ -10,7 +10,7 @@
 
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 
-<title>Class DataType &mdash; Torch-TensorRT v2.2.0.dev0+c7a26fa documentation</title>
+<title>Class DataType &mdash; Torch-TensorRT v2.3.0.dev0+85971ff documentation</title>
 
 
 
@@ -237,7 +237,7 @@
 
 
 <div class="version">
-  v2.2.0.dev0+c7a26fa
+  v2.3.0.dev0+85971ff
 </div>
 
 

docs/_cpp_api/classtorch__tensorrt_1_1Device_1_1DeviceType.html (+2 −2)

@@ -10,7 +10,7 @@
 
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 
-<title>Class Device::DeviceType &mdash; Torch-TensorRT v2.2.0.dev0+c7a26fa documentation</title>
+<title>Class Device::DeviceType &mdash; Torch-TensorRT v2.3.0.dev0+85971ff documentation</title>
 
 
 
@@ -237,7 +237,7 @@
 
 
 <div class="version">
-  v2.2.0.dev0+c7a26fa
+  v2.3.0.dev0+85971ff
 </div>
 
 

docs/_cpp_api/classtorch__tensorrt_1_1TensorFormat.html (+2 −2)

@@ -10,7 +10,7 @@
 
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 
-<title>Class TensorFormat &mdash; Torch-TensorRT v2.2.0.dev0+c7a26fa documentation</title>
+<title>Class TensorFormat &mdash; Torch-TensorRT v2.3.0.dev0+85971ff documentation</title>
 
 
 
@@ -237,7 +237,7 @@
 
 
 <div class="version">
-  v2.2.0.dev0+c7a26fa
+  v2.3.0.dev0+85971ff
 </div>
 
 

docs/_cpp_api/classtorch__tensorrt_1_1ptq_1_1Int8CacheCalibrator.html (+2 −2)

@@ -10,7 +10,7 @@
 
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 
-<title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v2.2.0.dev0+c7a26fa documentation</title>
+<title>Template Class Int8CacheCalibrator &mdash; Torch-TensorRT v2.3.0.dev0+85971ff documentation</title>
 
 
 
@@ -237,7 +237,7 @@
 
 
 <div class="version">
-  v2.2.0.dev0+c7a26fa
+  v2.3.0.dev0+85971ff
 </div>
 
 
