
🐛 [Bug] Cannot perform inference if the ExportedProgram has weighted layers and custom ops. #2576

Closed
Tracked by #2262
peri044 opened this issue Jan 5, 2024 · 3 comments
Labels
bug Something isn't working

@peri044
Collaborator

peri044 commented Jan 5, 2024

Bug Description

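  1. The graph contains a convolution node executed in PyTorch and a TensorRT engine node. The conv node's weight and bias are lifted as placeholders (graph inputs), so the graph expects them alongside the real input. This produces the runtime error below about a mismatch in the number of inputs.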

Error message:

_check_input_constraints_for_graph(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/utils.py", line 48, in _check_input_constraints_for_graph
    check(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/utils.py", line 40, in check
    raise RuntimeError(msg)
RuntimeError: Unexpected number of inputs (expected 3, got 1)
  2. If we unlift these parameters (i.e., conv_weight and conv_bias are registered as get_attr nodes), a different error occurs: GraphModule does not contain attribute conv_weight.
    Reason:
    A syntax error occurs in _create_graph_module_for_export, so the resulting gm does not have these attributes.

To Reproduce

Install the nightly version of Torch-TRT:

pip install --pre torch-tensorrt  --extra-index-url https://download.pytorch.org/whl/nightly/cu121

Run the following script to reproduce the error:

import torch
import torch_tensorrt

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Conv with bias: its weight and bias become lifted parameters on export.
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        conv = self.conv(x)
        relu = self.relu(conv)
        mul = relu * 0.5
        return mul

input = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
model = MyModule().eval().cuda()

compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            input.shape, dtype=torch.float, format=torch.contiguous_format
        )
    ],
    "ir": "dynamo",
    "min_block_size": 1,
    # Keep convolution in PyTorch so the graph mixes a PyTorch conv node
    # (with weight and bias) and a TensorRT engine node.
    "torch_executed_ops": {"torch.ops.aten.convolution.default"},
}

# Trace to an ExportedProgram, compile with TensorRT, then re-export.
exp_program = torch_tensorrt.dynamo.trace(model, **compile_spec)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, [input], ir="exported_program")

# Serialization round trip.
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

outputs_pyt = model(input)
# Raises the RuntimeError shown above (expected 3, got 1).
outputs_trt = trt_exp_program(input)

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:
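Several fields of the template above can be gathered programmatically. The sketch below is a hypothetical helper (`collect_environment` is not part of Torch-TensorRT or PyTorch); it reads only standard version attributes and reports "not installed" when a package cannot be imported. PyTorch's own `python -m torch.utils.collect_env` produces a fuller report.

```python
import platform


def collect_environment():
    """Hypothetical helper: gather the version fields the issue template asks for."""
    info = {
        "Python version": platform.python_version(),
        "OS": platform.system(),
        "CPU Architecture": platform.machine(),
    }
    try:
        import torch
    except ImportError:
        info["PyTorch Version"] = "not installed"
    else:
        info["PyTorch Version"] = torch.__version__
        # torch.version.cuda is None for CPU-only builds.
        info["CUDA version"] = torch.version.cuda or "none (CPU-only build)"
        if torch.cuda.is_available():
            info["GPU models"] = torch.cuda.get_device_name(0)
    try:
        import torch_tensorrt
    except (ImportError, OSError):
        # OSError covers installs whose native TensorRT libraries fail to load.
        info["Torch-TensorRT Version"] = "not installed"
    else:
        info["Torch-TensorRT Version"] = torch_tensorrt.__version__
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```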

Additional context

@peri044
Collaborator Author

peri044 commented Jan 5, 2024

Related PyTorch issue: pytorch/pytorch#116831

@peri044
Collaborator Author

peri044 commented Jan 16, 2024

This is fixed by #2575

peri044 closed this as completed Apr 16, 2024
@Hong753

Hong753 commented Nov 29, 2024

Hello, I am hitting the same issue on torch_tensorrt 2.5.0+cu118 while compiling a model that uses a custom CUDA extension.

RuntimeError: Unexpected number of inputs (expected 9, got 7)

How am I supposed to fix this?
