Commit 841a078
[Fix]: fix data type in fused-bias-leakyrelu for apex fp16 training (#…
nbei authored Apr 24, 2021
1 parent 9649a9a commit 841a078
Showing 1 changed file with 5 additions and 5 deletions.
mmcv/ops/fused_bias_leakyrelu.py
@@ -45,10 +45,9 @@ def backward(ctx, gradgrad_input, gradgrad_bias):
         # The second order deviation, in fact, contains two parts, while the
         # the first part is zero. Thus, we direct consider the second part
         # which is similar with the first order deviation in implementation.
-        gradgrad_out = ext_module.fused_bias_leakyrelu(gradgrad_input,
-                                                       gradgrad_bias, out, 3,
-                                                       1, ctx.negative_slope,
-                                                       ctx.scale)
+        gradgrad_out = ext_module.fused_bias_leakyrelu(
+            gradgrad_input, gradgrad_bias.to(out.dtype), out, 3, 1,
+            ctx.negative_slope, ctx.scale)
 
         return gradgrad_out, None, None, None

@@ -139,7 +138,8 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
     if not input.is_cuda:
         return bias_leakyrelu_ref(input, bias, negative_slope, scale)
 
-    return FusedBiasLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
+                                            negative_slope, scale)
 
 
 def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
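
For context, bias_leakyrelu_ref is the CPU fallback taken above when the input is not on CUDA: it adds a per-channel bias, applies a leaky ReLU, and rescales the result (the StyleGAN2 convention, with scale defaulting to sqrt(2)). A sketch consistent with that signature follows; it is a reconstruction under those assumptions, not the verbatim mmcv body:

import torch
import torch.nn.functional as F

def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
    # Broadcast the 1-D per-channel bias over dim 1 of x.
    if bias is not None:
        shape = [1, -1] + [1] * (x.ndim - 2)
        x = x + bias.reshape(shape)
    # Leaky ReLU, then rescale to roughly preserve signal magnitude.
    return F.leaky_relu(x, negative_slope) * scale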
