[Coverage] RuntimeError (torch.ops.aten.expand.default) during matrix multiplication in SwinTransformer block #3257

Open
Tracked by #3179
chohk88 opened this issue Oct 22, 2024 · 0 comments


While compiling a SwinTransformer block with Torch-TensorRT (torch.compile with backend="torch_tensorrt"), conversion fails with a RuntimeError raised by the converter for torch.ops.aten.expand.default. The failing node comes from the matrix multiplication in the window-attention module, specifically the attention-score computation q @ k.transpose(-2, -1) in the forward pass.
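The model never calls expand directly. For 4-D operands, `@` (torch.matmul) is evaluated as a batched matrix multiply, and in the ATen-level graph this generally appears as an expand of each operand to the broadcast batch shape, a reshape to 3-D, and aten.bmm. The sketch below replays that pattern in eager PyTorch on CPU, assuming the shapes from the traceback (B=1, H=4, N=49, D=16). It illustrates the general decomposition and is not a copy of PyTorch's internal matmul code.

```python
import torch

# Shapes from the issue: batch 1, 4 heads, 49 tokens per window, head dim 16.
B, H, N, D = 1, 4, 49, 16
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)

# Reference: the expression whose conversion fails.
attn_ref = q @ k.transpose(-2, -1)  # (B, H, N, N)

# Batched-matmul pattern: expand the batch dims to a common shape,
# fold them into one, run bmm, then restore the batch dims.
kt = k.transpose(-2, -1)                                           # (B, H, D, N)
batch_shape = torch.broadcast_shapes(q.shape[:-2], kt.shape[:-2])  # (B, H)
q_exp = q.expand(*batch_shape, N, D)    # same-rank expand, a no-op here
kt_exp = kt.expand(*batch_shape, D, N)
attn_bmm = torch.bmm(q_exp.reshape(-1, N, D), kt_exp.reshape(-1, D, N))
attn_dec = attn_bmm.reshape(*batch_shape, N, N)

print(attn_dec.shape)  # torch.Size([1, 4, 49, 49])
```

Here q_exp has exactly the shape [1, 4, 49, 16], the same target that appears in the failing expand node.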

The following error message is produced:

Traceback (most recent call last):
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/backend/backends.py", line 114, in _pretraced_backend
    trt_compiled = compile_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 487, in compile_module
    trt_module = convert_module(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 141, in convert_module
    interpreter_result = interpret_module_to_result(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 120, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 617, in run
    self._construct_trt_network_def()
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 348, in _construct_trt_network_def
    super().run()
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 683, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/root/.pyenv/versions/3.10.15/lib/python3.10/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 792, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 539, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1170, in aten_ops_expand
    return impl.slice.expand(
  File "/opt/torch_tensorrt/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 240, in expand
    raise RuntimeError(
RuntimeError: expand called with 4-dimensional shape on Tensor with 4 dimensions. Cannot expand to shape with rank smaller than original tensor.

While executing %expand : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%mul, [1, 4, 49, 16]), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7ff359701cf0>: ((1, 49, 64), torch.float32, False, (3136, 64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215430>: ((64,), torch.float32, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582155b0>: ((64,), torch.float32, True, (1,), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358207db0>: ((1, 49, 64), torch.float32, False, (3136, 64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358207d30>: ((192, 64), torch.float32, True, (64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358204ef0>: ((64, 192), torch.float32, False, (1, 64), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358206c70>: ((49, 64), torch.float32, False, (64, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582059f0>: ((49, 192), torch.float32, False, (192, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358205af0>: ((1, 49, 192), torch.float32, False, (9408, 192, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215cf0>: ((1, 49, 3, 4, 16), torch.float32, False, (9408, 192, 64, 16, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358215df0>: ((3, 1, 4, 49, 16), torch.float32, False, (64, 9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216130>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582161f0>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216270>: ((1, 4, 49, 16), torch.float32, False, (9408, 16, 192, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff358216630>: ((1, 4, 49, 16), torch.float32, False, (3136, 16, 64, 1), None, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff3582168b0>: ((1, 4, 16, 49), torch.float32, False, (9408, 16, 1, 192), None, False, {})}})
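The tensor metadata in the message traces q through the qkv projection: (1, 49, 192) → (1, 49, 3, 4, 16) → (3, 1, 4, 49, 16) → (1, 4, 49, 16), followed by the scaling multiply (%mul). The expand target [1, 4, 49, 16] therefore has the same rank and shape as its input, and the error text itself reports 4 dimensions on both sides. The sketch below replays those steps in eager PyTorch (CPU, random data, shapes from the reproduction script) to show that the same expand is accepted there. It does not exercise the Torch-TensorRT converter.

```python
import torch

# Shapes from the reproduction: one 7x7 window (B_=1, N=49), C=64, 4 heads.
B_, N, C, num_heads = 1, 49, 64, 4
head_dim = C // num_heads  # 16

qkv_out = torch.randn(B_, N, 3 * C)                   # (1, 49, 192), like self.qkv(x)
qkv = qkv_out.reshape(B_, N, 3, num_heads, head_dim)  # (1, 49, 3, 4, 16)
qkv = qkv.permute(2, 0, 3, 1, 4)                      # (3, 1, 4, 49, 16)
q_raw = qkv[0]                                        # (1, 4, 49, 16)
q = q_raw * head_dim ** -0.5                          # corresponds to the %mul node

print(q_raw.stride())  # (9408, 16, 192, 1), matching the metadata in the message

target = [1, 4, 49, 16]     # expand target of the failing node
q_exp = q.expand(*target)   # eager PyTorch accepts this same-rank expand
print(q.dim(), len(target), q_exp.shape == q.shape)  # 4 4 True
```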

Reproduction Code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt

# Set device and backend
backend = "torch_tensorrt"
device = torch.device("cuda:0")

class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # Error happens here
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return self.proj_drop(x)

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x

# Example input and usage
dim = 64
num_heads = 4
window_size = 7

x = torch.randn(1, 49, dim).to(device)  # Example input (B, N, C)
block = SwinTransformerBlock(dim, num_heads, window_size)
block.eval()
model = block.to(device)

# Compile the block with the Torch-TensorRT backend
block = torch.compile(
    block,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16, torch.float32},
        "device": device,
        "min_block_size": 5,
        "require_full_compilation": True
    },
    dynamic=False,
)

outputs_after = block(x)  # Error occurs here (call the compiled module; `model` is the uncompiled eager module)
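One possible workaround to try while the converter issue is open, not verified against Torch-TensorRT: compute the attention scores with an explicit 3-D torch.bmm, folding the batch and head dimensions by hand so that no batch broadcast (and so, presumably, no aten.expand node) is needed. The helper name attention_scores_bmm is hypothetical. The sketch only checks that it matches q @ k.transpose(-2, -1) in eager PyTorch on CPU.

```python
import torch


def attention_scores_bmm(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: q @ k^T for (B, H, N, D) inputs via explicit 3-D bmm.

    Assumption: with batch and head dims folded manually, the traced graph
    should contain reshape + bmm instead of the expand nodes produced by a
    4-D matmul. This is not verified against Torch-TensorRT.
    """
    B, H, N, D = q.shape
    q3 = q.reshape(B * H, N, D)                # fold batch and head dims
    k3 = k.reshape(B * H, N, D)
    attn3 = torch.bmm(q3, k3.transpose(1, 2))  # (B*H, N, N)
    return attn3.reshape(B, H, N, N)


# Equivalence check against the original expression (random data, CPU).
q = torch.randn(1, 4, 49, 16)
k = torch.randn(1, 4, 49, 16)
ref = q @ k.transpose(-2, -1)
out = attention_scores_bmm(q, k)
print(out.shape)  # torch.Size([1, 4, 49, 49])
```

In WindowAttention.forward, `attn @ v` is also a 4-D matmul and would presumably need the same rewrite for the whole block to avoid the expand converter.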