76
76
∇softmax! (Δ, x; dims= 1 ) = ∇softmax! (Δ, Δ, x, softmax (x, dims= dims); dims= dims)
77
77
∇softmax! (out, Δ, x; dims= 1 ) = ∇softmax! (out, Δ, x, softmax (x, dims= dims); dims= dims)
78
78
79
- function ChainRulesCore . rrule (:: typeof (softmax), xs; dims= 1 )
79
+ function rrule (:: typeof (softmax), xs; dims= 1 )
80
80
y = softmax (xs; dims= dims)
81
81
softmax_pullback (Δ) = (NO_FIELDS, ∇softmax (Δ, xs, y, dims= dims))
82
82
return y, softmax_pullback
112
112
113
113
∇logsoftmax (Δ:: AbstractArray{T} , x:: AbstractArray , y:: AbstractArray{S} ; dims = 1 ) where {T,S} =
114
114
∇logsoftmax! (similar (y, promote_type (T, S)), Δ, x, y; dims = dims)
115
-
115
+
116
116
# Old 2-arg version recomputing forward
117
117
∇logsoftmax (Δ, x; dims= 1 ) = ∇logsoftmax (Δ, x, logsoftmax (x, dims= dims); dims= dims)
118
118
∇logsoftmax! (Δ, x; dims= 1 ) = ∇logsoftmax! (Δ, Δ, x, logsoftmax (x, dims= dims); dims= dims)
@@ -123,7 +123,7 @@ function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
123
123
out .= Δ .- sum (Δ, dims = dims) .* exp .(y)
124
124
end
125
125
126
- function ChainRulesCore . rrule (:: typeof (logsoftmax), xs; dims= 1 )
126
+ function rrule (:: typeof (logsoftmax), xs; dims= 1 )
127
127
y = logsoftmax (xs; dims= dims)
128
128
logsoftmax_pullback (Δ) = (NO_FIELDS, ∇logsoftmax (Δ, xs, y, dims= dims))
129
129
return y, logsoftmax_pullback
0 commit comments