-
-
Notifications
You must be signed in to change notification settings - Fork 30
Fix AD for parameters #175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,37 +3,47 @@ using Integrals | |
if isdefined(Base, :get_extension) | ||
using Zygote | ||
import ChainRulesCore | ||
import ChainRulesCore: NoTangent | ||
import ChainRulesCore: NoTangent, ProjectTo | ||
else | ||
using ..Zygote | ||
import ..Zygote.ChainRulesCore | ||
import ..Zygote.ChainRulesCore: NoTangent | ||
import ..Zygote.ChainRulesCore: NoTangent, ProjectTo | ||
end | ||
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) | ||
ChainRulesCore.@non_differentiable Integrals.isinplace(f, n) # fixes #99 | ||
|
||
function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub, | ||
p; | ||
kwargs...) | ||
out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) | ||
|
||
# the adjoint will be the integral of the input sensitivities, so it maps the | ||
# sensitivity of the output to an object of the type of the parameters | ||
function quadrature_adjoint(Δ) | ||
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ | ||
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes | ||
y = cache.nout == 1 ? Δ[1] : Δ # interpret the output as scalar | ||
# this will not be type-stable, but I believe it is unavoidable due to two ambiguities: | ||
# 1. Δ is the output of the algorithm, and when nout = 1 it is undefined whether the | ||
# output of the algorithm must be a scalar or a vector of length 1 | ||
# 2. when nout = 1 the integrand can either be a scalar or a vector of length 1 | ||
if isinplace(cache) | ||
dx = zeros(cache.nout) | ||
_f = x -> cache.f(dx, x, p) | ||
if sensealg.vjp isa Integrals.ZygoteVJP | ||
dfdp = function (dx, x, p) | ||
_, back = Zygote.pullback(p) do p | ||
_dx = Zygote.Buffer(x, cache.nout, size(x, 2)) | ||
z, back = Zygote.pullback(p) do p | ||
_dx = cache.nout == 1 ? | ||
Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) : | ||
Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x))) | ||
cache.f(_dx, x, p) | ||
copy(_dx) | ||
end | ||
|
||
z = zeros(size(x, 2)) | ||
for idx in 1:size(x, 2) | ||
z[1] = 1 | ||
dx[:, idx] = back(z)[1] | ||
z[idx] = 0 | ||
z .= zero(eltype(z)) | ||
for idx in 1:size(x, ndims(x)) | ||
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y) | ||
dx[:, idx] .= back(z)[1] | ||
z isa Vector ? (z[idx] = zero(eltype(z))) : | ||
(z[:, idx] .= zero(eltype(z))) | ||
end | ||
end | ||
elseif sensealg.vjp isa Integrals.ReverseDiffVJP | ||
|
@@ -44,14 +54,21 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal | |
if sensealg.vjp isa Integrals.ZygoteVJP | ||
if cache.batch > 0 | ||
dfdp = function (x, p) | ||
_, back = Zygote.pullback(p -> cache.f(x, p), p) | ||
z, back = Zygote.pullback(p -> cache.f(x, p), p) | ||
# messy, there are 4 cases, some better in forward mode than reverse | ||
# 1: length(y) == 1 and length(p) == 1 | ||
# 2: length(y) > 1 and length(p) == 1 | ||
# 3: length(y) == 1 and length(p) > 1 | ||
# 4: length(y) > 1 and length(p) > 1 | ||
Comment on lines
+58
to
+62
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure what this comment is here for. I mean, agreed, these are the 4 cases and sometimes forward is better than reverse, but I don't understand why that's here 😅 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah sorry, I just left that as a note to myself and I can remove it. The lines of code below need all the if statements to handle these cases, so I found it helpful to list them. |
||
|
||
out = zeros(length(p), size(x, 2)) | ||
z = zeros(size(x, 2)) | ||
for idx in 1:size(x, 2) | ||
z[idx] = 1 | ||
out[:, idx] = back(z)[1] | ||
z[idx] = 0 | ||
z .= zero(eltype(z)) | ||
out = zeros(eltype(p), size(p)..., size(x, ndims(x))) | ||
for idx in 1:size(x, ndims(x)) | ||
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y) | ||
out isa Vector ? (out[idx] = back(z)[1]) : | ||
(out[:, idx] .= back(z)[1]) | ||
z isa Vector ? (z[idx] = zero(y)) : | ||
(z[:, idx] .= zero(eltype(y))) | ||
end | ||
out | ||
end | ||
|
@@ -76,17 +93,30 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal | |
do_inf_transformation = Val(false), | ||
cache.kwargs...) | ||
|
||
if p isa Number | ||
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1] | ||
else | ||
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u | ||
end | ||
project_p = ProjectTo(p) | ||
dp = project_p(Integrals.__solvebp_call(dp_cache, | ||
alg, | ||
sensealg, | ||
lb, | ||
ub, | ||
p; | ||
kwargs...).u) | ||
|
||
if lb isa Number | ||
dlb = -_f(lb) | ||
dub = _f(ub) | ||
dlb = cache.batch > 0 ? -_f([lb]) : -_f(lb) | ||
dub = cache.batch > 0 ? _f([ub]) : _f(ub) | ||
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp) | ||
else | ||
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we | ||
# can see from writing the multidimensional integral as an iterated integral | ||
# alternatively we can use Stokes' theorem to replace the integral on the | ||
# boundary with a volume integral of the flux of the integrand | ||
# ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the | ||
# dimensionality of the integral or the quadrature used (such as quadratures | ||
# that don't evaluate points on the boundaries) and it could be generalized to | ||
# other kinds of domains. The only question is to determine ω in terms of f and | ||
# the deformation of the surface (e.g. consider integral over an ellipse and | ||
# asking for the derivative of the result w.r.t. the semiaxes of the ellipse) | ||
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), | ||
NoTangent(), dp) | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't the `nout = 1` case be unable to use `similar`, because it could be a scalar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since `p` is an array (see the function signature), so are `rawp` (look below) and `dualp`, so I think `similar` will be defined.