diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 5c45bd4ebb..13ff04ea09 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -70,6 +70,7 @@ def forward(ctx, ctx.deform_groups = deform_groups ctx.im2col_step = im2col_step + weight = weight.type_as(input) ctx.save_for_backward(input, offset, weight) output = input.new_empty(