
There are issues with support for ConvTranspose2d #946

Open
mortal-Zero opened this issue Sep 5, 2024 · 2 comments

Comments

@mortal-Zero

Hello, and thank you for your outstanding project.
I encountered an error when converting a model containing ConvTranspose2d with torch2trt. Here are the code and the error:

```python
import torch
import torch.nn as nn
from torch2trt import torch2trt

# ConvTranspose2d (32 -> 64 channels, 2x upsampling), then BatchNorm2d and LeakyReLU
model = nn.Sequential(
    nn.ConvTranspose2d(in_channels=32, out_channels=64,
                       kernel_size=4, stride=2,
                       padding=1, bias=True),
    nn.BatchNorm2d(num_features=64),
    nn.LeakyReLU()
)
model.to("cuda:0").eval()
x = torch.zeros([1, 32, 16, 16]).to("cuda:0")
y = model(x)  # the plain PyTorch forward pass works
print("=====>> input: {} || output: {}".format(x.shape, y.shape))
model_trt = torch2trt(model, [x])  # conversion fails here
```
```
=====>> input: torch.Size([1, 32, 16, 16]) || output: torch.Size([1, 64, 32, 32])
[09/05/2024-11:09:34] [TRT] [E] 3: 0:0:DECONVOLUTION:GPU:kernel weights has count 32768 but 16384 was expected
[09/05/2024-11:09:34] [TRT] [E] 4: 0:0:DECONVOLUTION:GPU: count of 32768 weights in kernel, but kernel dimensions (4,4) with 32 input channels, 32 output channels and 1 groups were specified. Expected Weights count is 32 * 4*4 * 32 / 1 = 16384
[09/05/2024-11:09:34] [TRT] [E] 4: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2212] Error Code 4: Internal Error (0:0:DECONVOLUTION:GPU: output shape can not be computed)
[09/05/2024-11:09:34] [TRT] [E] 3: [network.cpp::addScaleNd::1162] Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addScaleNd::1162, condition: qdqScale || basicScale
)
Traceback (most recent call last):
  File "/workspace/baiyixuan/test_cvcuda/digitalhuman_service/debug_codes/test.py", line 30, in <module>
    model_trt = torch2trt(model, [x])
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/torch2trt.py", line 643, in torch2trt
    outputs = module(*inputs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/torch2trt.py", line 262, in wrapper
    converter["converter"](ctx)
  File "/root/miniconda3/envs/cvcuda/lib/python3.10/site-packages/torch2trt-0.5.0-py3.10-linux-x86_64.egg/torch2trt/converters/native_converters.py", line 183, in convert_batch_norm
    output._trt = layer.get_output(0)
AttributeError: 'NoneType' object has no attribute 'get_output'
```
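For reference, both weight counts in the TensorRT error can be reproduced from PyTorch's documented weight layout, in which `ConvTranspose2d.weight` has shape `(in_channels, out_channels // groups, kH, kW)`. The sketch below is plain Python (no torch or TensorRT needed), and the helper names are hypothetical, chosen only for illustration:

```python
# Hypothetical helpers that reproduce the weight counts quoted in the TensorRT log.

def conv_transpose2d_weight_shape(in_ch, out_ch, kh, kw, groups=1):
    # PyTorch stores ConvTranspose2d weights as (in_channels, out_channels // groups, kH, kW)
    return (in_ch, out_ch // groups, kh, kw)


def trt_expected_deconv_count(in_ch, out_maps, kh, kw, groups=1):
    # Formula printed by TensorRT: in_channels * kH*kW * out_channels / groups
    return in_ch * kh * kw * out_maps // groups


def numel(shape):
    # Number of elements in a tensor of the given shape
    n = 1
    for d in shape:
        n *= d
    return n


shape = conv_transpose2d_weight_shape(32, 64, 4, 4)
print(shape, numel(shape))  # (32, 64, 4, 4) 32768

# TensorRT was told 32 output channels, so it expected half as many weights
print(trt_expected_deconv_count(32, 32, 4, 4))  # 16384
# With the model's real out_channels=64 the counts agree
print(trt_expected_deconv_count(32, 64, 4, 4))  # 32768
```

The kernel holds exactly the 32768 values needed for 64 output maps, so the log suggests the deconvolution layer was created with 32 output maps (the `in_channels` value) instead of 64. Because that layer's output shape could not be computed, the following `addScaleNd` call for BatchNorm2d apparently failed as well, which would explain why `convert_batch_norm` received `None`. This reading is inferred from the log only, not confirmed against the torch2trt source.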

Looking forward to your reply.

@mortal-Zero (author)

For comparison, the following code, which uses Conv2d instead of ConvTranspose2d, converts without any error:

```python
import torch
import torch.nn as nn
from torch2trt import torch2trt

# Same Conv -> BatchNorm2d -> LeakyReLU pattern, but with a regular Conv2d (6 -> 32 channels)
model = nn.Sequential(
    nn.Conv2d(in_channels=6, out_channels=32,
              kernel_size=3, stride=1,
              padding=1, bias=True),
    nn.BatchNorm2d(num_features=32),
    nn.LeakyReLU(),
)
model.to("cuda:0").eval()
x = torch.zeros([1, 6, 96, 96]).to("cuda:0")
y = model(x)
print("=====>> input: {} || output: {}".format(x.shape, y.shape))
model_trt = torch2trt(model, [x])  # converts without errors
```
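One plausible reason the Conv2d version converts while the ConvTranspose2d one does not (an assumption, not verified against the torch2trt converter code): PyTorch stores `Conv2d.weight` as `(out_channels, in_channels // groups, kH, kW)` but `ConvTranspose2d.weight` as `(in_channels, out_channels // groups, kH, kW)`. Code that reads the output-channel count from the first weight dimension is therefore right for the first model and wrong for the second. A plain-Python sketch with hypothetical helper names:

```python
# Hypothetical helpers mirroring PyTorch's documented weight layouts.

def conv2d_weight_shape(in_ch, out_ch, kh, kw, groups=1):
    # Conv2d: (out_channels, in_channels // groups, kH, kW)
    return (out_ch, in_ch // groups, kh, kw)


def conv_transpose2d_weight_shape(in_ch, out_ch, kh, kw, groups=1):
    # ConvTranspose2d: (in_channels, out_channels // groups, kH, kW)
    return (in_ch, out_ch // groups, kh, kw)


groups = 1
conv = conv2d_weight_shape(6, 32, 3, 3, groups)               # working model
deconv = conv_transpose2d_weight_shape(32, 64, 4, 4, groups)  # failing model

# Taking out_channels from dimension 0 is correct for Conv2d ...
print(conv[0])             # 32, matches out_channels=32
# ... but yields in_channels for ConvTranspose2d
print(deconv[0])           # 32, although out_channels=64
# For ConvTranspose2d, out_channels is dimension 1 times groups
print(deconv[1] * groups)  # 64
```

The 32 printed for `deconv[0]` matches the "32 output channels" reported by TensorRT for the failing model.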

@fasogbon

Is there any fix, please? I am having the same problem.
