Fix AD for parameters #175


Merged: 3 commits, Sep 17, 2023
5 changes: 3 additions & 2 deletions ext/IntegralsForwardDiffExt.jl
@@ -29,7 +29,8 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
dfdp = function (out, x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
if cache.batch > 0
dx = similar(dualp, cache.nout, size(x, 2))
dx = cache.nout == 1 ? similar(dualp, size(x, ndims(x))) :
similar(dualp, cache.nout, size(x, ndims(x)))

Member
Won't the nout = 1 case be unable to use similar, because it could be scalar?

Collaborator Author
Since p is an array (see the function signature), so are rawp (defined below) and dualp, so I think similar will be defined.
else
dx = similar(dualp, cache.nout)
end
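As a follow-up to the thread above, here is a minimal standalone sketch (not part of the PR) of why `similar` is available in the nout = 1 case: the raw parameter vector is reinterpreted as an array of Duals, so `dualp` is an AbstractArray even when the problem conceptually has a single scalar parameter. The tag, values, and sizes below are invented for illustration.

```julia
using ForwardDiff

# hypothetical flattened parameters: two Dual{*, Float64, 1} values laid out as
# (value, partial, value, partial)
rawp = [1.0, 0.5, 2.0, 0.25]
dualp = reinterpret(ForwardDiff.Dual{Nothing, Float64, 1}, rawp)

@assert dualp isa AbstractArray && length(dualp) == 2
dx = similar(dualp, 3)   # works: allocates a length-3 Vector of Duals
```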
@@ -49,7 +50,7 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
ys = cache.f(x, dualp)
if cache.batch > 0
out = similar(p, V, nout, size(x, 2))
out = similar(p, V, nout, size(x, ndims(x)))
else
out = similar(p, V, nout)
end
80 changes: 55 additions & 25 deletions ext/IntegralsZygoteExt.jl
@@ -3,37 +3,47 @@ using Integrals
if isdefined(Base, :get_extension)
using Zygote
import ChainRulesCore
import ChainRulesCore: NoTangent
import ChainRulesCore: NoTangent, ProjectTo
else
using ..Zygote
import ..Zygote.ChainRulesCore
import ..Zygote.ChainRulesCore: NoTangent
import ..Zygote.ChainRulesCore: NoTangent, ProjectTo
end
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)
ChainRulesCore.@non_differentiable Integrals.isinplace(f, n) # fixes #99

function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub,
p;
kwargs...)
out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)

# the adjoint will be the integral of the input sensitivities, so it maps the
# sensitivity of the output to an object of the type of the parameters
function quadrature_adjoint(Δ)
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ
# https://juliadiff.org/ChainRulesCore.jl/dev/design/many_tangents.html#manytypes
y = cache.nout == 1 ? Δ[1] : Δ # interpret the output as scalar
# this will not be type-stable, but I believe it is unavoidable due to two ambiguities:
# 1. Δ is the output of the algorithm, and when nout = 1 it is undefined whether the
# output of the algorithm must be a scalar or a vector of length 1
# 2. when nout = 1 the integrand can either be a scalar or a vector of length 1
if isinplace(cache)
dx = zeros(cache.nout)
_f = x -> cache.f(dx, x, p)
if sensealg.vjp isa Integrals.ZygoteVJP
dfdp = function (dx, x, p)
_, back = Zygote.pullback(p) do p
_dx = Zygote.Buffer(x, cache.nout, size(x, 2))
z, back = Zygote.pullback(p) do p
_dx = cache.nout == 1 ?
Zygote.Buffer(dx, eltype(y), size(x, ndims(x))) :
Zygote.Buffer(dx, eltype(y), cache.nout, size(x, ndims(x)))
cache.f(_dx, x, p)
copy(_dx)
end

z = zeros(size(x, 2))
for idx in 1:size(x, 2)
z[1] = 1
dx[:, idx] = back(z)[1]
z[idx] = 0
z .= zero(eltype(z))
for idx in 1:size(x, ndims(x))
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y)
dx[:, idx] .= back(z)[1]
z isa Vector ? (z[idx] = zero(eltype(z))) :
(z[:, idx] .= zero(eltype(z)))
end
end
elseif sensealg.vjp isa Integrals.ReverseDiffVJP
@@ -44,14 +54,21 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
if sensealg.vjp isa Integrals.ZygoteVJP
if cache.batch > 0
dfdp = function (x, p)
_, back = Zygote.pullback(p -> cache.f(x, p), p)
z, back = Zygote.pullback(p -> cache.f(x, p), p)
# messy, there are 4 cases, some better in forward mode than reverse
# 1: length(y) == 1 and length(p) == 1
# 2: length(y) > 1 and length(p) == 1
# 3: length(y) == 1 and length(p) > 1
# 4: length(y) > 1 and length(p) > 1
Member, on lines +58 to +62
I'm not sure what this comment is here for? I mean, agreed these are the 4 cases and sometimes forward is better than reverse, but I don't understand why that's here 😅

Collaborator Author
Ah sorry, I just left that as a note to myself and I can remove it. The lines of code below need all the if statements to handle these cases, so I found it helpful to list them.

out = zeros(length(p), size(x, 2))
z = zeros(size(x, 2))
for idx in 1:size(x, 2)
z[idx] = 1
out[:, idx] = back(z)[1]
z[idx] = 0
z .= zero(eltype(z))
out = zeros(eltype(p), size(p)..., size(x, ndims(x)))
for idx in 1:size(x, ndims(x))
z isa Vector ? (z[idx] = y) : (z[:, idx] .= y)
out isa Vector ? (out[idx] = back(z)[1]) :
(out[:, idx] .= back(z)[1])
z isa Vector ? (z[idx] = zero(y)) :
(z[:, idx] .= zero(eltype(y)))
end
out
end
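To make the column-by-column seeding pattern from the thread above concrete, here is a minimal standalone sketch (not the package code) of the out-of-place batched case; the integrand, sizes, and parameter values are invented for illustration.

```julia
using Zygote

p = [2.0, 3.0]
x = rand(2, 4)   # a batch of 4 two-dimensional quadrature points
# hypothetical batched integrand returning an nout × nbatch matrix (nout = 2 here)
f(x, p) = vcat(p[1] .* sum(x; dims = 1), p[2] .* sum(x; dims = 1) .^ 2)

y, back = Zygote.pullback(p -> f(x, p), p)
out = zeros(length(p), size(x, ndims(x)))
z = zero(y)
for idx in axes(x, 2)
    z[:, idx] .= 1.0            # seed the output sensitivity for one batch point
    out[:, idx] .= back(z)[1]   # vector-Jacobian product w.r.t. p for that point
    z[:, idx] .= 0.0            # reset the seed before the next point
end
```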
@@ -76,17 +93,30 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal
do_inf_transformation = Val(false),
cache.kwargs...)

if p isa Number
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]
else
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u
end
project_p = ProjectTo(p)
dp = project_p(Integrals.__solvebp_call(dp_cache,
alg,
sensealg,
lb,
ub,
p;
kwargs...).u)

if lb isa Number
dlb = -_f(lb)
dub = _f(ub)
dlb = cache.batch > 0 ? -_f([lb]) : -_f(lb)
dub = cache.batch > 0 ? _f([ub]) : _f(ub)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp)
else
# we need to compute 2*length(lb) integrals on the faces of the hypercube, as we
# can see from writing the multidimensional integral as an iterated integral
# alternatively we can use Stokes' theorem to replace the integral on the
# boundary with a volume integral of the flux of the integrand
# ∫∂Ω ω = ∫Ω dω, which would be better since we won't have to change the
# dimensionality of the integral or the quadrature used (such as quadratures
# that don't evaluate points on the boundaries) and it could be generalized to
# other kinds of domains. The only question is to determine ω in terms of f and
# the deformation of the surface (e.g. consider integral over an ellipse and
# asking for the derivative of the result w.r.t. the semiaxes of the ellipse)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NoTangent(), dp)
end
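For the scalar-limit branch above, a short sketch (not from the PR) of the rule it implements: by the Leibniz rule, the sensitivities of the integral with respect to its limits are the integrand evaluated at those limits, with a sign flip for the lower limit.

```julia
# d/db ∫_a^b f(x) dx = f(b),  d/da ∫_a^b f(x) dx = -f(a)
f(x) = exp(-x^2)   # hypothetical scalar integrand
a, b = 0.0, 1.0
dub = f(b)         # sensitivity w.r.t. the upper limit
dlb = -f(a)        # sensitivity w.r.t. the lower limit
```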
54 changes: 16 additions & 38 deletions lib/IntegralsCubature/src/IntegralsCubature.jl
@@ -54,6 +54,10 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
maxiters = typemax(Int))
nout = prob.nout
if nout == 1
# the output of prob.f could be either scalar or a vector of length 1, however
# the behavior of the output of the integration routine is undefined (could differ
# across algorithms)
# Cubature will output a real number when called without nout/fdim
if prob.batch == 0
if isinplace(prob)
dx = zeros(eltype(lb), prob.nout)
@@ -63,74 +67,52 @@ end
end
if lb isa Number
if alg isa CubatureJLh
_val, err = Cubature.hquadrature(f, lb, ub;
val, err = Cubature.hquadrature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pquadrature(f, lb, ub;
val, err = Cubature.pquadrature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
val = prob.f(lb, p) isa Number ? _val : [_val]
else
if alg isa CubatureJLh
_val, err = Cubature.hcubature(f, lb, ub;
val, err = Cubature.hcubature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pcubature(f, lb, ub;
val, err = Cubature.pcubature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end

if isinplace(prob) || !isa(prob.f(lb, p), Number)
val = [_val]
else
val = _val
end
end
else
if isinplace(prob)
f = (x, dx) -> prob.f(dx', x, p)
elseif lb isa Number
if prob.f([lb ub], p) isa Vector
f = (x, dx) -> (dx .= prob.f(x', p))
else
f = function (x, dx)
dx[:] = prob.f(x', p)
end
end
f = (x, dx) -> prob.f(dx, x, p)
else
if prob.f([lb ub], p) isa Vector
f = (x, dx) -> (dx .= prob.f(x, p))
else
f = function (x, dx)
dx .= prob.f(x, p)[:]
end
end
f = (x, dx) -> (dx .= prob.f(x, p))
end
if lb isa Number
if alg isa CubatureJLh
_val, err = Cubature.hquadrature_v(f, lb, ub;
val, err = Cubature.hquadrature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pquadrature_v(f, lb, ub;
val, err = Cubature.pquadrature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
else
if alg isa CubatureJLh
_val, err = Cubature.hcubature_v(f, lb, ub;
val, err = Cubature.hcubature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
_val, err = Cubature.pcubature_v(f, lb, ub;
val, err = Cubature.pcubature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
end
val = _val isa Number ? [_val] : _val
end
else
if prob.batch == 0
@@ -166,13 +148,9 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
else
if isinplace(prob)
f = (x, dx) -> prob.f(dx, x, p)
f = (x, dx) -> (prob.f(dx, x, p); dx)
else
if lb isa Number
f = (x, dx) -> (dx .= prob.f(x', p))
else
f = (x, dx) -> (dx .= prob.f(x, p))
end
f = (x, dx) -> (dx .= prob.f(x, p))
end

if lb isa Number
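Relating to the nout == 1 comment added in the Cubature wrapper above: as far as I can tell, Cubature.jl's non-vectorized hquadrature/pquadrature take a scalar-valued integrand and return a plain number together with an error estimate, which is what the comment means by the routine outputting a real number. A hedged sketch:

```julia
using Cubature

# scalar integrand, scalar result: val is a Float64, not a 1-element vector
val, err = Cubature.hquadrature(x -> x^2, 0.0, 1.0)
# val ≈ 1/3; err is an absolute error estimate
```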
26 changes: 13 additions & 13 deletions test/derivative_tests.jl
@@ -1,4 +1,4 @@
using Integrals, Zygote, FiniteDiff, ForwardDiff, SciMLSensitivity
using Integrals, Zygote, FiniteDiff, ForwardDiff#, SciMLSensitivity
using IntegralsCuba, IntegralsCubature
using Test

@@ -117,7 +117,7 @@ dp4 = ForwardDiff.gradient(p -> testf(lb, ub, p), p)
@test dp1 ≈ dp4

### Batch Single dim
f(x, p) = x * p[1] .+ p[2] * p[3]
f(x, p) = x * p[1] .+ p[2] * p[3] # scalar integrand

lb = 1.0
ub = 3.0
@@ -130,14 +130,14 @@ function testf3(lb, ub, p; f = f)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)")
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1] # TODO fix: LoadError: DimensionMismatch("variable with size(x) == (1, 15) cannot have a gradient with size(dx) == (15,)")
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3 #passes
@test_broken dp2 ≈ dp3 #passes
@test dp2 ≈ dp3 #passes

### Batch single dim, nout
f(x, p) = (x * p[1] .+ p[2] * p[3]) .* [1; 2]
f(x, p) = (x' * p[1] .+ p[2] * p[3]) .* [1; 2]

lb = 1.0
ub = 3.0
@@ -150,11 +150,11 @@ function testf3(lb, ub, p; f = f)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3 #passes
# @test dp2 ≈ dp3 #passes
@test dp2 ≈ dp3 #passes

### Batch multi dim
f(x, p) = x[1, :] * p[1] .+ p[2] * p[3]
@@ -190,15 +190,15 @@ function testf3(lb, ub, p; f = f)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3
# @test dp2 ≈ dp3
@test dp2 ≈ dp3

## iip Batch mulit dim
## iip Batch multi dim
function g(dx, x, p)
dx .= sum(x * p[1] .+ p[2] * p[3], dims = 1)
dx .= dropdims(sum(x * p[1] .+ p[2] * p[3], dims = 1), dims = 1)
end

lb = [1.0, 1.0]
@@ -236,8 +236,8 @@ function testf3(lb, ub, p; f = g)
end

dp1 = ForwardDiff.gradient(p -> testf3(lb, ub, p), p)
# dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp2 = Zygote.gradient(p -> testf3(lb, ub, p), p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p -> testf3(lb, ub, p), p)

@test dp1 ≈ dp3
# @test dp2 ≈ dp3
@test dp2 ≈ dp3