Skip to content

Commit f5ed85b

Browse files
port batchnorm rrule from Flux (#499)
1 parent e5cff84 commit f5ed85b

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/normalization.jl

+10
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,13 @@
22
function batchnorm end
33

44
function ∇batchnorm end
5+
6+
7+
function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
8+
y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
9+
function batchnorm_pullback(Δ)
10+
grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
11+
(NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
12+
end
13+
y, batchnorm_pullback
14+
end

0 commit comments

Comments
 (0)