
Remove linear lowering pass and converter #3323

Merged

Conversation

@HolyWu (Contributor) commented Dec 13, 2024

1. The linear lowering pass only takes effect with ir="torch_compile" and nn.Linear(bias=True); it has no effect at all with ir="dynamo", regardless of whether bias is True or False. Using the code below:

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(20, 30, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


with torch.inference_mode():
    model = MyModule().eval().cuda().half()
    inputs = [torch.randn(128, 20, dtype=torch.half, device="cuda")]

    trt_model = torch_tensorrt.compile(
        model, "torch_compile", inputs, enabled_precisions={torch.half}, debug=True, min_block_size=1
    )

    trt_model(*inputs)

ir="torch_compile", bias=True

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.lower_linear:Graph after lowering linear:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %linear_default : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg2_1, %arg0_1, %arg1_1), kwargs = {})
    return (linear_default,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %linear_default : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg2_1, %arg0_1, %arg1_1), kwargs = {})
    return (linear_default,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %linear_default : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg2_1, %arg0_1, %arg1_1), kwargs = {})
    return (linear_default,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.linear.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

ir="torch_compile", bias=False

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.permute.default + Operator Count: 1
- torch.ops.aten.mm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

ir="dynamo", bias=True

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    return (linear,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

ir="dynamo", bias=False

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight), kwargs = {})
    return (linear,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %x : [num_users=1] = placeholder[target=x]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %_frozen_param0), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %_frozen_param0), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %_frozen_param0), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.mm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
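
To make point 1 concrete: judging from the debug logs above, the pass is an FX graph rewrite that collapses the permute + addmm pair into a single aten.linear call. The sketch below is a simplified, hypothetical version of such a rewrite, not the actual lower_linear source. It illustrates why there is nothing left to match when the graph contains mm instead of addmm (bias=False), or when constant folding has already absorbed the permute into a frozen parameter (ir="dynamo").

import torch
import torch.fx


def lower_linear_sketch(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical rewrite: addmm(bias, x, permute(weight)) -> linear(x, weight, bias)
    for node in list(gm.graph.nodes):
        if node.target != torch.ops.aten.addmm.default:
            continue
        bias, inp, permuted = node.args
        if (
            not isinstance(permuted, torch.fx.Node)
            or permuted.target != torch.ops.aten.permute.default
        ):
            continue
        weight = permuted.args[0]
        # Insert a single aten.linear call and drop the addmm/permute pair
        with gm.graph.inserting_after(node):
            linear = gm.graph.call_function(
                torch.ops.aten.linear.default, args=(inp, weight, bias)
            )
        node.replace_all_uses_with(linear)
        gm.graph.erase_node(node)
        if not permuted.users:
            gm.graph.erase_node(permuted)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm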

2. The linear converter's performance is essentially the same as not using the converter at all. Using the code below:

from __future__ import annotations

import os

import numpy as np
import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

times = 100


@torch.inference_mode()
def benchmark(model: torch.nn.Module, inputs: list[torch.Tensor]) -> np.ndarray:
    # Warm up
    for i in range(3):
        model(inputs[i])

    torch.cuda.synchronize()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]

    for i in range(times):
        torch.cuda._sleep(1_000_000)

        start_events[i].record()
        model(inputs[i])
        end_events[i].record()

    torch.cuda.synchronize()

    timings = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return np.array(timings)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4096, 8192)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


torch.manual_seed(12345)

model = MyModule().eval().cuda().half()

inputs = [torch_tensorrt.Input((2048, 4096), dtype=torch.half)]

trt_model = torch_tensorrt.compile(
    model, "torch_compile", inputs, enabled_precisions={torch.half}, debug=False, min_block_size=1
)

inputs = [torch.randn(2048, 4096, dtype=torch.half, device="cuda") for _ in range(times)]

timing = benchmark(trt_model, inputs)

print("\nTiming:")
print(f"Min={timing.min()} ms, Mean={timing.mean()} ms, Max={timing.max()} ms")

torch._dynamo.reset()

ir="torch_compile", with linear pass

Timing:
Min=2.2077438831329346 ms, Mean=2.27082528591156 ms, Max=3.074399948120117 ms

ir="torch_compile", linear pass removed

Timing:
Min=2.20467209815979 ms, Mean=2.2625676822662353 ms, Max=2.8375039100646973 ms

ir="dynamo"

Timing:
Min=2.0244479179382324 ms, Mean=2.063061113357544 ms, Max=2.6060800552368164 ms

3. TestLowerLinear is flaky. Recently, TestLinearConverter has also been observed to hit threshold failures on Windows CI.
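
For reference on the threshold failures: the kind of parity check such a test performs boils down to comparing the TensorRT output with eager PyTorch under an FP16 tolerance. Below is a minimal, hypothetical version of that check; the rtol/atol values are illustrative only and are not the thresholds used in the actual test suite.

import torch
import torch_tensorrt


@torch.inference_mode()
def check_linear_parity() -> None:
    model = torch.nn.Linear(4096, 8192).eval().cuda().half()
    x = torch.randn(2048, 4096, dtype=torch.half, device="cuda")

    trt_model = torch_tensorrt.compile(
        model,
        "torch_compile",
        [torch_tensorrt.Input((2048, 4096), dtype=torch.half)],
        enabled_precisions={torch.half},
        min_block_size=1,
    )

    ref = model(x)
    out = trt_model(x)

    # FP16 GEMMs can legitimately differ between the TensorRT kernel and the
    # kernel eager PyTorch picks (typically cuBLAS), which is what makes fixed
    # thresholds prone to occasional failures.
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)


check_linear_parity()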

For the reasons mentioned above, I think the linear lowering pass and converter provide no benefit and should be removed.

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Dec 13, 2024
@github-actions github-actions bot requested a review from peri044 December 13, 2024 17:20
@peri044 (Collaborator) left a comment

This is great and thanks for the perf analysis. LGTM

@peri044 (Collaborator) commented Dec 16, 2024

@HolyWu can you rebase?

@peri044 peri044 merged commit d7071ba into pytorch:main Dec 17, 2024
68 checks passed
@HolyWu HolyWu deleted the remove_linear_lowering_pass_and_converter branch December 18, 2024 11:37