diff --git a/src/controlnet_aux/pidi/model.py b/src/controlnet_aux/pidi/model.py index 16595b3..4522f2a 100644 --- a/src/controlnet_aux/pidi/model.py +++ b/src/controlnet_aux/pidi/model.py @@ -330,10 +330,7 @@ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): padding = 2 * dilation shape = weights.shape - if weights.is_cuda: - buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) - else: - buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device) + buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device) weights = weights.view(shape[0], shape[1], -1) buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]