diff --git a/Project.toml b/Project.toml
index 184f921..3033ef9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,11 +3,19 @@ uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
 version = "0.1.0"
 
+[deps]
+ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
 [compat]
 julia = "1"
 
 [extras]
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test"]
+test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "Zygote"]
diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl
index 93d7351..96e9dc5 100644
--- a/src/AbstractDifferentiation.jl
+++ b/src/AbstractDifferentiation.jl
@@ -1,5 +1,635 @@
 module AbstractDifferentiation
 
-# Write your package code here.
+using LinearAlgebra, ExprTools
+
+export AD
+
+const AD = AbstractDifferentiation
+
+abstract type AbstractBackend end
+abstract type AbstractFiniteDifference <: AbstractBackend end
+abstract type AbstractForwardMode <: AbstractBackend end
+abstract type AbstractReverseMode <: AbstractBackend end
+
+struct HigherOrderBackend{B} <: AbstractBackend
+    backends::B
+end
+reduceorder(b::AbstractBackend) = b
+function reduceorder(b::HigherOrderBackend)
+    if length(b.backends)==1
+        return lowest(b) # prevent zero tuple and subsequent error when reducing over HigherOrderBackend
+    else
+        return HigherOrderBackend(reverse(Base.tail(reverse(b.backends))))
+    end
+end
+lowest(b::AbstractBackend) = b
+lowest(b::HigherOrderBackend) = b.backends[end]
+secondlowest(b::AbstractBackend) = b
+secondlowest(b::HigherOrderBackend) = lowest(reduceorder(b))
+
+# If the primal value is in ys, extract it.
+# Otherwise, re-compute it, e.g. in finite diff.
+primalvalue(::AbstractFiniteDifference, ::Any, f, xs) = f(xs...)
+primalvalue(::AbstractBackend, ys, ::Any, ::Any) = primalvalue(ys)
+primalvalue(x::Tuple) = map(primalvalue, x)
+primalvalue(x) = x
+
+function derivative(ab::AbstractBackend, f, xs::Number...)
+    der = getindex.(jacobian(lowest(ab), f, xs...), 1)
+    if der isa Tuple
+        return der
+    else
+        return (der,)
+    end
+end
+
+function gradient(ab::AbstractBackend, f, xs...)
+    return reshape.(adjoint.(jacobian(lowest(ab), f, xs...)), size.(xs))
+end
+function jacobian(ab::AbstractBackend, f, xs...) end
+function jacobian(ab::HigherOrderBackend, f, xs...)
+    jacobian(lowest(ab), f, xs...)
+end
+
+function hessian(ab::AbstractBackend, f, x)
+    if x isa Tuple
+        # only support computation of Hessian for functions with single input argument
+        @assert length(x) == 1
+        x = x[1]
+    end
+    return jacobian(secondlowest(ab), x -> begin
+        gradient(lowest(ab), f, x)[1] # gradient returns a tuple
+    end, x)
+end
+
+function value_and_derivative(ab::AbstractBackend, f, xs::Number...)
+    value, jacs = value_and_jacobian(lowest(ab), f, xs...)
+    return value[1], getindex.(jacs, 1)
+end
+function value_and_gradient(ab::AbstractBackend, f, xs...)
+    value, jacs = value_and_jacobian(lowest(ab), f, xs...)
+    return value, reshape.(adjoint.(jacs), size.(xs))
+end
+function value_and_jacobian(ab::AbstractBackend, f, xs...)
+    local value
+    primalcalled = false
+    if lowest(ab) isa AbstractFiniteDifference
+        value = primalvalue(ab, nothing, f, xs)
+        primalcalled = true
+    end
+    jacs = jacobian(lowest(ab), (_xs...,) -> begin
+        v = f(_xs...)
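+        # Capture the primal as a side effect of this first evaluation so
+        # `f` is not called an extra time; finite differences, which would
+        # evaluate `f` at perturbed points, were handled above.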
+ if !primalcalled + value = primalvalue(ab, v, f, xs) + primalcalled = true + end + return v + end, xs...) + + return value, jacs +end +function value_and_hessian(ab::AbstractBackend, f, x) + if x isa Tuple + # only support computation of Hessian for functions with single input argument + @assert length(x) == 1 + x = x[1] + end + + local value + primalcalled = false + if ab isa AbstractFiniteDifference + value = primalvalue(ab, nothing, f, (x,)) + primalcalled = true + end + hess = jacobian(secondlowest(ab), _x -> begin + v, g = value_and_gradient(lowest(ab), f, _x) + if !primalcalled + value = primalvalue(ab, v, f, (x,)) + primalcalled = true + end + return g[1] # gradient returns a tuple + end, x) + return value, hess +end +function value_and_hessian(ab::HigherOrderBackend, f, x) + if x isa Tuple + # only support computation of Hessian for functions with single input argument + @assert length(x) == 1 + x = x[1] + end + local value + primalcalled = false + hess = jacobian(secondlowest(ab), (_x,) -> begin + v, g = value_and_gradient(lowest(ab), f, _x) + if !primalcalled + value = primalvalue(ab, v, f, (x,)) + primalcalled = true + end + return g[1] # gradient returns a tuple + end, x) + return value, hess +end +function value_gradient_and_hessian(ab::AbstractBackend, f, x) + if x isa Tuple + # only support computation of Hessian for functions with single input argument + @assert length(x) == 1 + x = x[1] + end + local value + primalcalled = false + grads, hess = value_and_jacobian(secondlowest(ab), _x -> begin + v, g = value_and_gradient(lowest(ab), f, _x) + if !primalcalled + value = primalvalue(secondlowest(ab), v, f, (x,)) + primalcalled = true + end + return g[1] # gradient returns a tuple + end, x) + return value, (grads,), hess +end +function value_gradient_and_hessian(ab::HigherOrderBackend, f, x) + if x isa Tuple + # only support computation of Hessian for functions with single input argument + @assert length(x) == 1 + x = x[1] + end + local value + primalcalled = false + grads, hess = value_and_jacobian(secondlowest(ab), _x -> begin + v, g = value_and_gradient(lowest(ab), f, _x) + if !primalcalled + value = primalvalue(secondlowest(ab), v, f, (x,)) + primalcalled = true + end + return g[1] # gradient returns a tuple + end, x) + return value, (grads,), hess +end + +function pushforward_function( + ab::AbstractBackend, + f, + xs..., +) + return (ds) -> begin + return jacobian(lowest(ab), (xds...,) -> begin + if ds isa Tuple + @assert length(xs) == length(ds) + newxs = xs .+ ds .* xds + return f(newxs...) + else + @assert length(xs) == length(xds) == 1 + newx = xs[1] + ds * xds[1] + return f(newx) + end + end, _zero.(xs, ds)...) + end +end +function value_and_pushforward_function( + ab::AbstractBackend, + f, + xs..., +) + return (ds) -> begin + if !(ds isa Tuple) + ds = (ds,) + end + @assert length(ds) == length(xs) + local value + primalcalled = false + if ab isa AbstractFiniteDifference + value = primalvalue(ab, nothing, f, xs) + primalcalled = true + end + pf = pushforward_function(lowest(ab), (_xs...,) -> begin + vs = f(_xs...) 
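+            # Same capture pattern as in `value_and_jacobian`: the primal
+            # is recorded on the first call of the wrapped function.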
+ if !primalcalled + value = primalvalue(lowest(ab), vs, f, xs) + primalcalled = true + end + return vs + end, xs...)(ds) + + return value, pf + end +end + +_zero(::Number, d::Number) = zero(d) +_zero(::Number, d::AbstractVector) = zero(d) +_zero(::AbstractVector, d::AbstractVector) = zero(eltype(d)) +_zero(::AbstractVector, d::AbstractMatrix) = zero(similar(d, size(d, 2))) +_zero(::AbstractMatrix, d::AbstractMatrix) = zero(d) +_zero(::Any, d::Any) = zero(d) + +function pullback_function(ab::AbstractBackend, f, xs...) + return (ws) -> begin + jacs = jacobian(lowest(ab), (xs...,) -> begin + vs = f(xs...) + if ws isa Tuple + @assert length(vs) == length(ws) + return sum(zip(vs, ws)) do v, w + if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector + return w' * v + else + # for arbitrary arrays + return dot(w, v) + end + end + else + w, v = ws, vs + if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector + return w' * v + else + # for arbitrary arrays + return dot(w, v) + end + end + end, xs...) + return adjoint.(jacs) + end +end +function value_and_pullback_function( + ab::AbstractBackend, + f, + xs..., +) + return (ws) -> begin + local value + primalcalled = false + if ab isa AbstractFiniteDifference + value = primalvalue(ab, nothing, f, xs) + primalcalled = true + end + if ws === nothing + vs = f(xs...) + if !primalcalled + value = primalvalue(lowest(ab), vs, f, xs) + primalcalled = true + end + return value, nothing + end + pb = pullback_function(lowest(ab), (_xs...,) -> begin + vs = f(_xs...) + if !primalcalled + value = primalvalue(lowest(ab), vs, f, xs) + primalcalled = true + end + return vs + end, xs...)(ws) + return value, pb + end +end + +struct LazyDerivative{B, F, X} + backend::B + f::F + xs::X +end + +function Base.:*(d::LazyDerivative, y) + return derivative(d.backend, d.f, d.xs...) * y +end + +function Base.:*(y, d::LazyDerivative) + return y * derivative(d.backend, d.f, d.xs...) +end + +function Base.:*(d::LazyDerivative, y::Union{Number,Tuple}) + if y isa Tuple && d.xs isa Tuple + @assert length(y) == length(d.xs) + end + return derivative(d.backend, d.f, d.xs...) .* y +end + +function Base.:*(y::Union{Number,Tuple}, d::LazyDerivative) + if y isa Tuple && d.xs isa Tuple + @assert length(y) == length(d.xs) + end + return y .* derivative(d.backend, d.f, d.xs...) +end + +function Base.:*(d::LazyDerivative, y::AbstractArray) + return map((d)-> d*y, derivative(d.backend, d.f, d.xs...)) +end + +function Base.:*(y::AbstractArray, d::LazyDerivative) + return map((d)-> y*d, derivative(d.backend, d.f, d.xs...)) +end + + +struct LazyGradient{B, F, X} + backend::B + f::F + xs::X +end +Base.:*(d::LazyGradient, y) = gradient(d.backend, d.f, d.xs...) * y +Base.:*(y, d::LazyGradient) = y * gradient(d.backend, d.f, d.xs...) + +function Base.:*(d::LazyGradient, y::Union{Number,Tuple}) + if y isa Tuple && d.xs isa Tuple + @assert length(y) == length(d.xs) + end + if d.xs isa Tuple + return gradient(d.backend, d.f, d.xs...) .* y + else + return gradient(d.backend, d.f, d.xs) .* y + end +end + +function Base.:*(y::Union{Number,Tuple}, d::LazyGradient) + if y isa Tuple && d.xs isa Tuple + @assert length(y) == length(d.xs) + end + if d.xs isa Tuple + return y .* gradient(d.backend, d.f, d.xs...) 
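+        # `d.xs` is a tuple of inputs here, so it is splatted; the branch
+        # below handles a single (non-tuple) input.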
+    else
+        return y .* gradient(d.backend, d.f, d.xs)
+    end
+end
+
+
+struct LazyJacobian{B, F, X}
+    backend::B
+    f::F
+    xs::X
+end
+
+function Base.:*(d::LazyJacobian, ys)
+    if !(ys isa Tuple)
+        ys = (ys, )
+    end
+    if d.xs isa Tuple
+        vjp = pushforward_function(d.backend, d.f, d.xs...)(ys)
+    else
+        vjp = pushforward_function(d.backend, d.f, d.xs)(ys)
+    end
+    if vjp isa Tuple
+        return vjp
+    else
+        return (vjp,)
+    end
+end
+
+function Base.:*(ys, d::LazyJacobian)
+    if ys isa Tuple
+        ya = adjoint.(ys)
+    else
+        ya = adjoint(ys)
+    end
+    if d.xs isa Tuple
+        return pullback_function(d.backend, d.f, d.xs...)(ya)
+    else
+        return pullback_function(d.backend, d.f, d.xs)(ya)
+    end
+end
+
+function Base.:*(d::LazyJacobian, ys::Number)
+    if d.xs isa Tuple
+        return jacobian(d.backend, d.f, d.xs...) .* ys
+    else
+        return jacobian(d.backend, d.f, d.xs) .* ys
+    end
+end
+
+function Base.:*(ys::Number, d::LazyJacobian)
+    if d.xs isa Tuple
+        return jacobian(d.backend, d.f, d.xs...) .* ys
+    else
+        return jacobian(d.backend, d.f, d.xs) .* ys
+    end
+end
+
+
+struct LazyHessian{B, F, X}
+    backend::B
+    f::F
+    xs::X
+end
+
+function Base.:*(d::LazyHessian, ys)
+    if !(ys isa Tuple)
+        ys = (ys, )
+    end
+
+    if d.xs isa Tuple
+        res = pushforward_function(
+            secondlowest(d.backend),
+            (xs...,) -> gradient(lowest(d.backend), d.f, xs...)[1], d.xs...,)(ys) # [1] because gradient returns a tuple
+    else
+        res = pushforward_function(
+            secondlowest(d.backend),
+            (xs,) -> gradient(lowest(d.backend), d.f, xs)[1], d.xs,)(ys) # gradient returns a tuple
+    end
+    if res isa Tuple
+        return res
+    else
+        return (res,)
+    end
+end
+
+function Base.:*(ys, d::LazyHessian)
+    if ys isa Tuple
+        ya = adjoint.(ys)
+    else
+        ya = adjoint(ys)
+    end
+    if d.xs isa Tuple
+        return pullback_function(
+            secondlowest(d.backend),
+            (xs...,) -> gradient(lowest(d.backend), d.f, xs...),
+            d.xs...,
+        )(ya)
+    else
+        return pullback_function(
+            secondlowest(d.backend),
+            (xs,) -> gradient(lowest(d.backend), d.f, xs)[1],
+            d.xs,
+        )(ya)
+    end
+end
+
+function Base.:*(d::LazyHessian, ys::Number)
+    if d.xs isa Tuple
+        return hessian(d.backend, d.f, d.xs...).*ys
+    else
+        return hessian(d.backend, d.f, d.xs).*ys
+    end
+end
+
+function Base.:*(ys::Number, d::LazyHessian)
+    if d.xs isa Tuple
+        return ys.*hessian(d.backend, d.f, d.xs...)
+    else
+        return ys.*hessian(d.backend, d.f, d.xs)
+    end
+end
+
+
+function lazyderivative(ab::AbstractBackend, f, xs::Number...)
+    return LazyDerivative(ab, f, xs)
+end
+function lazygradient(ab::AbstractBackend, f, xs...)
+    return LazyGradient(ab, f, xs)
+end
+function lazyhessian(ab::AbstractBackend, f, xs...)
+    return LazyHessian(ab, f, xs)
+end
+function lazyjacobian(ab::AbstractBackend, f, xs...)
+    return LazyJacobian(ab, f, xs)
+end
+
+struct D{B, F}
+    backend::B
+    f::F
+end
+D(b::AbstractBackend, d::D) = H(HigherOrderBackend((b, d.backend)), d.f)
+D(d::D) = H(HigherOrderBackend((d.backend, d.backend)), d.f)
+function (d::D)(xs...; lazy = true)
+    if lazy
+        return lazyjacobian(d.backend, d.f, xs...)
+    else
+        return jacobian(d.backend, d.f, xs...)
+    end
+end
+
+struct H{B, F}
+    backend::B
+    f::F
+end
+function (h::H)(xs...; lazy = true)
+    if lazy
+        return lazyhessian(h.backend, h.f, xs...)
+    else
+        return hessian(h.backend, h.f, xs...)
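+        # Non-lazy path: materializes the full Hessian immediately; the
+        # default lazy = true returns a LazyHessian wrapper instead.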
+ end +end + +macro primitive(expr) + fdef = ExprTools.splitdef(expr) + name = fdef[:name] + if name == :pushforward_function + return define_pushforward_function_and_friends(fdef) |> esc + elseif name == :pullback_function + return define_pullback_function_and_friends(fdef) |> esc + elseif name == :jacobian + return define_jacobian_and_friends(fdef) |> esc + elseif name == :primalvalue + return define_primalvalue(fdef) |> esc + else + throw("Unsupported AD primitive.") + end +end + +function define_pushforward_function_and_friends(fdef) + fdef[:name] = :(AbstractDifferentiation.pushforward_function) + args = fdef[:args] + funcs = quote + $(ExprTools.combinedef(fdef)) + function AbstractDifferentiation.jacobian($(args...),) + identity_like = AbstractDifferentiation.identity_matrix_like($(args[3:end]...),) + pff = AbstractDifferentiation.pushforward_function($(args...),) + if eltype(identity_like) <: Tuple{Vararg{Union{AbstractMatrix, Number}}} + return map(identity_like) do identity_like_i + return mapreduce(hcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) + pff(cols) + end + end + elseif eltype(identity_like) <: AbstractMatrix + # needed for the computation of the Hessian and Jacobian + ret = hcat.(mapslices(identity_like[1], dims=1) do cols + # cols loop over basis states + pf = pff((cols,)) + if typeof(pf) <: AbstractVector + # to make the hcat. work / get correct matrix-like, non-flat output dimension + return (pf, ) + else + return pf + end + end ...) + return ret isa Tuple ? ret : (ret,) + + else + return pff(identity_like) + end + end + end + return funcs +end + +function define_pullback_function_and_friends(fdef) + fdef[:name] = :(AbstractDifferentiation.pullback_function) + args = fdef[:args] + funcs = quote + $(ExprTools.combinedef(fdef)) + function AbstractDifferentiation.jacobian($(args...),) + value_and_pbf = AbstractDifferentiation.value_and_pullback_function($(args...),) + value, _ = value_and_pbf(nothing) + identity_like = AbstractDifferentiation.identity_matrix_like(value) + if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}} + return map(identity_like) do identity_like_i + return mapreduce(vcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) + value_and_pbf(cols)[2]' + end + end + elseif eltype(identity_like) <: AbstractMatrix + # needed for Hessian computation: + # value is a (grad,). Then, identity_like is a (matrix,). + # cols loops over columns of the matrix + return vcat.(mapslices(identity_like[1], dims=1) do cols + adjoint.(value_and_pbf((cols,))[2]) + end ...) + else + return adjoint.(value_and_pbf(identity_like)[2]) + end + end + end + return funcs +end + +_eachcol(a::Number) = (a,) +_eachcol(a) = eachcol(a) + +function define_jacobian_and_friends(fdef) + fdef[:name] = :(AbstractDifferentiation.jacobian) + return ExprTools.combinedef(fdef) +end + +function define_primalvalue(fdef) + fdef[:name] = :(AbstractDifferentiation.primalvalue) + return ExprTools.combinedef(fdef) +end + +function identity_matrix_like(x) + throw("The function `identity_matrix_like` is not defined for the type $(typeof(x)).") +end +function identity_matrix_like(x::AbstractVector) + return (Matrix{eltype(x)}(I, length(x), length(x)),) +end +function identity_matrix_like(x::Number) + return (one(x),) +end +identity_matrix_like(x::Tuple) = identity_matrix_like(x...) +@generated function identity_matrix_like(x...) 
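+    # Builds, at compile time, a tuple of tuples: block (i, j) is an
+    # identity-like block when i == j and a zero block otherwise, i.e. a
+    # block-diagonal identity spanning all input arguments.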
+    expr = :(())
+    for i in 1:length(x)
+        push!(expr.args, :(()))
+        for j in 1:i-1
+            push!(expr.args[i].args, :((zero_matrix_like(x[$j])[1])))
+        end
+        push!(expr.args[i].args, :((identity_matrix_like(x[$i]))[1]))
+        for j in i+1:length(x)
+            push!(expr.args[i].args, :(zero_matrix_like(x[$j])[1]))
+        end
+    end
+    return expr
+end
+
+zero_matrix_like(x::Tuple) = zero_matrix_like(x...)
+zero_matrix_like(x...) = map(zero_matrix_like, x)
+zero_matrix_like(x::AbstractVector) = (zero(similar(x, length(x), length(x))),)
+zero_matrix_like(x::Number) = (zero(x),)
+function zero_matrix_like(x)
+    throw("The function `zero_matrix_like` is not defined for the type $(typeof(x)).")
+end
 
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index a8a79fc..1bf305a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,637 @@
 using AbstractDifferentiation
-using Test
+using Test, FiniteDifferences, LinearAlgebra
+using ForwardDiff
+using Zygote
+using Random
+Random.seed!(1234)
+
+const FDM = FiniteDifferences
+
+## FiniteDifferences
+struct FDMBackend1{A} <: AD.AbstractFiniteDifference
+    alg::A
+end
+FDMBackend1() = FDMBackend1(central_fdm(5, 1))
+const fdm_backend1 = FDMBackend1()
+# Minimal interface
+AD.@primitive function jacobian(ab::FDMBackend1, f, xs...)
+    return FDM.jacobian(ab.alg, f, xs...)
+end
+
+struct FDMBackend2{A} <: AD.AbstractFiniteDifference
+    alg::A
+end
+FDMBackend2() = FDMBackend2(central_fdm(5, 1))
+const fdm_backend2 = FDMBackend2()
+AD.@primitive function pushforward_function(ab::FDMBackend2, f, xs...)
+    return function (vs)
+        FDM.jvp(ab.alg, f, tuple.(xs, vs)...)
+    end
+end
+
+struct FDMBackend3{A} <: AD.AbstractFiniteDifference
+    alg::A
+end
+FDMBackend3() = FDMBackend3(central_fdm(5, 1))
+const fdm_backend3 = FDMBackend3()
+AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...)
+    return function (vs)
+        # Supports only single output
+        if vs isa AbstractVector
+            return FDM.j′vp(ab.alg, f, vs, xs...)
+        else
+            @assert length(vs) == 1
+            return FDM.j′vp(ab.alg, f, vs[1], xs...)
+        end
+    end
+end
+##
+
+
+## ForwardDiff
+struct ForwardDiffBackend1 <: AD.AbstractForwardMode end
+const forwarddiff_backend1 = ForwardDiffBackend1()
+AD.@primitive function jacobian(ab::ForwardDiffBackend1, f, xs)
+    if xs isa Number
+        return (ForwardDiff.derivative(f, xs),)
+    elseif xs isa AbstractArray
+        out = f(xs)
+        if out isa Number
+            return (adjoint(ForwardDiff.gradient(f, xs)),)
+        else
+            return (ForwardDiff.jacobian(f, xs),)
+        end
+    else
+        # tuples and all other input types are unsupported here
+        error(typeof(xs))
+    end
+end
+AD.primalvalue(::ForwardDiffBackend1, ::Any, f, xs) = ForwardDiff.value.(f(xs...))
+
+struct ForwardDiffBackend2 <: AD.AbstractForwardMode end
+const forwarddiff_backend2 = ForwardDiffBackend2()
+AD.@primitive function pushforward_function(ab::ForwardDiffBackend2, f, xs...)
+    # jvp = f'(x)*v, i.e., differentiate f(x + h*v) wrt h at 0
+    return function (vs)
+        if xs isa Tuple
+            @assert length(xs) <= 2
+            if length(xs) == 1
+                (ForwardDiff.derivative(h->f(xs[1]+h*vs[1]),0),)
+            else
+                ForwardDiff.derivative(h->f(xs[1]+h*vs[1], xs[2]+h*vs[2]),0)
+            end
+        else
+            ForwardDiff.derivative(h->f(xs+h*vs),0)
+        end
+    end
+end
+AD.primalvalue(::ForwardDiffBackend2, ::Any, f, xs) = ForwardDiff.value.(f(xs...))
+##
+
+## Zygote
+struct ZygoteBackend1 <: AD.AbstractReverseMode end
+const zygote_backend1 = ZygoteBackend1()
+AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...)
+    return function (vs)
+        # Supports only single output
+        _, back = Zygote.pullback(f, xs...)
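+        # The primal returned by `Zygote.pullback` is discarded here; the
+        # generic `value_and_pullback_function` wrapper recovers it separately.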
+ if vs isa AbstractVector + back(vs) + else + @assert length(vs) == 1 + back(vs[1]) + end + end +end +## + +fder(x, y) = exp(y) * x + y * log(x) +dfderdx(x, y) = exp(y) + y * 1/x +dfderdy(x, y) = exp(y) * x + log(x) + +fgrad(x, y) = prod(x) + sum(y ./ (1:length(y))) +dfgraddx(x, y) = prod(x)./x +dfgraddy(x, y) = one(eltype(y)) ./ (1:length(y)) +dfgraddxdx(x, y) = prod(x)./(x*x') - Diagonal(diag(prod(x)./(x*x'))) +dfgraddydy(x, y) = zeros(length(y),length(y)) + +function fjac(x, y) + x + -3*y + [y[2:end];zero(y[end])]/2# Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U) * y +end +dfjacdx(x, y) = I(length(x)) +dfjacdy(x, y) = Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U) + +# Jvp +jxvp(x,y,v) = dfjacdx(x,y)*v +jyvp(x,y,v) = dfjacdy(x,y)*v + +# vJp +vJxp(x,y,v) = dfjacdx(x,y)'*v +vJyp(x,y,v) = dfjacdy(x,y)'*v + +const xscalar = rand() +const yscalar = rand() + +const xvec = rand(5) +const yvec = rand(5) + +# to check if vectors get mutated +xvec2 = deepcopy(xvec) +yvec2 = deepcopy(yvec) + +function test_higher_order_backend(backends...) + ADbackends = AD.HigherOrderBackend(backends) + @test backends[end] == AD.lowest(ADbackends) + @test backends[end-1] == AD.secondlowest(ADbackends) + + for i in length(backends):-1:1 + @test backends[i] == AD.lowest(ADbackends) + ADbackends = AD.reduceorder(ADbackends) + end + backends[1] == AD.reduceorder(ADbackends) +end + +function test_derivatives(backend; multiple_inputs=true) + # test with respect to analytical solution + der_exact = (dfderdx(xscalar,yscalar), dfderdy(xscalar,yscalar)) + if multiple_inputs + der1 = AD.derivative(backend, fder, xscalar, yscalar) + @test minimum(isapprox.(der_exact, der1, rtol=1e-10)) + valscalar, der2 = AD.value_and_derivative(backend, fder, xscalar, yscalar) + @test valscalar == fder(xscalar, yscalar) + @test der2 .- der1 == (0, 0) + end + # test if single input (no tuple works) + valscalara, dera = AD.value_and_derivative(backend, x -> fder(x, yscalar), xscalar) + valscalarb, derb = AD.value_and_derivative(backend, y -> fder(xscalar, y), yscalar) + @test fder(xscalar, yscalar) == valscalara + @test fder(xscalar, yscalar) == valscalarb + @test isapprox(dera[1], der_exact[1], rtol=1e-10) + @test isapprox(derb[1], der_exact[2], rtol=1e-10) +end + +function test_gradients(backend; multiple_inputs=true) + # test with respect to analytical solution + grad_exact = (dfgraddx(xvec,yvec), dfgraddy(xvec,yvec)) + if multiple_inputs + grad1 = AD.gradient(backend, fgrad, xvec, yvec) + @test minimum(isapprox.(grad_exact, grad1, rtol=1e-10)) + valscalar, grad2 = AD.value_and_gradient(backend, fgrad, xvec, yvec) + @test valscalar == fgrad(xvec, yvec) + @test norm.(grad2 .- grad1) == (0, 0) + end + # test if single input (no tuple works) + valscalara, grada = AD.value_and_gradient(backend, x -> fgrad(x, yvec), xvec) + valscalarb, gradb = AD.value_and_gradient(backend, y -> fgrad(xvec, y), yvec) + @test fgrad(xvec, yvec) == valscalara + @test fgrad(xvec, yvec) == valscalarb + @test isapprox(grada[1], grad_exact[1], rtol=1e-10) + @test isapprox(gradb[1], grad_exact[2], rtol=1e-10) + @test xvec == xvec2 + @test yvec == yvec2 +end + +function test_jacobians(backend; multiple_inputs=true) + # test with respect to analytical solution + jac_exact = (dfjacdx(xvec, yvec), dfjacdy(xvec, yvec)) + if multiple_inputs + jac1 = AD.jacobian(backend, fjac, xvec, yvec) + @test minimum(isapprox.(jac_exact, jac1, rtol=1e-10)) + valvec, jac2 = AD.value_and_jacobian(backend, fjac, xvec, yvec) + @test valvec == fjac(xvec, yvec) + 
@test norm.(jac2 .- jac1) == (0, 0) + end + + # test if single input (no tuple works) + valveca, jaca = AD.value_and_jacobian(backend, x -> fjac(x, yvec), xvec) + valvecb, jacb = AD.value_and_jacobian(backend, y -> fjac(xvec, y), yvec) + @test fjac(xvec, yvec) == valveca + @test fjac(xvec, yvec) == valvecb + @test isapprox(jaca[1], jac_exact[1], rtol=1e-10) + @test isapprox(jacb[1], jac_exact[2], rtol=1e-10) + @test xvec == xvec2 + @test yvec == yvec2 +end + +function test_hessians(backend; multiple_inputs=false) + if multiple_inputs + # ... but + error("multiple_inputs=true is not supported.") + else + # explicit test that AbstractDifferentiation throws an error + # don't support tuple of Hessians + @test_throws AssertionError H1 = AD.hessian(backend, fgrad, (xvec, yvec)) + @test_throws MethodError H1 = AD.hessian(backend, fgrad, xvec, yvec) + end + + # @test dfgraddxdx(xvec,yvec) ≈ H1[1] atol=1e-10 + # @test dfgraddydy(xvec,yvec) ≈ H1[2] atol=1e-10 + + # test if single input (no tuple works) + fhess = x -> fgrad(x, yvec) + hess1 = AD.hessian(backend, fhess, xvec) + # test with respect to analytical solution + @test dfgraddxdx(xvec,yvec) ≈ hess1[1] atol=1e-10 + + valscalar, hess2 = AD.value_and_hessian(backend, fhess, xvec) + @test valscalar == fgrad(xvec, yvec) + @test norm.(hess2 .- hess1) == (0,) + valscalar, grad, hess3 = AD.value_gradient_and_hessian(backend, fhess, xvec) + @test valscalar == fgrad(xvec, yvec) + @test norm.(grad .- AD.gradient(backend, fhess, xvec)) == (0,) + @test norm.(hess3 .- hess1) == (0,) + + @test xvec == xvec2 + @test yvec == yvec2 + fhess2 = x-> dfgraddx(x, yvec) + hess4 = AD.jacobian(backend, fhess2, xvec) + @test minimum(isapprox.(hess4, hess1, atol=1e-10)) +end + +function test_jvp(backend; multiple_inputs=true) + v = (rand(length(xvec)), rand(length(yvec))) + + if multiple_inputs + if backend isa Union{FDMBackend2,ForwardDiffBackend2} # augmented version of v + identity_like = AD.identity_matrix_like(v) + vaug = map(identity_like) do identity_like_i + identity_like_i .* v + end + + pf1 = map(v->AD.pushforward_function(backend, fjac, xvec, yvec)(v), vaug) + ((valvec1, pf2x), (valvec2, pf2y)) = map(v->AD.value_and_pushforward_function(backend, fjac, xvec, yvec)(v), vaug) + else + pf1 = AD.pushforward_function(backend, fjac, xvec, yvec)(v) + valvec, pf2 = AD.value_and_pushforward_function(backend, fjac, xvec, yvec)(v) + ((valvec1, pf2x), (valvec2, pf2y)) = (valvec, pf2[1]), (valvec, pf2[2]) + end + + @test valvec1 == fjac(xvec, yvec) + @test valvec2 == fjac(xvec, yvec) + @test norm.((pf2x,pf2y) .- pf1) == (0, 0) + # test with respect to analytical solution + @test minimum(isapprox.(pf1, (jxvp(xvec,yvec,v[1]), jyvp(xvec,yvec,v[2])), atol=1e-10)) + @test xvec == xvec2 + @test yvec == yvec2 + end + + valvec1, pf1 = AD.value_and_pushforward_function(backend, x -> fjac(x, yvec), xvec)(v[1]) + valvec2, pf2 = AD.value_and_pushforward_function(backend, y -> fjac(xvec, y), yvec)(v[2]) + + if backend isa Union{FDMBackend2} + pf1 = (pf1,) + pf2 = (pf2,) + end + @test valvec1 == fjac(xvec, yvec) + @test valvec2 == fjac(xvec, yvec) + @test minimum(isapprox.((pf1[1],pf2[1]), (jxvp(xvec,yvec,v[1]), jyvp(xvec,yvec,v[2])), atol=1e-10)) +end + +function test_j′vp(backend; multiple_inputs=true) + # test with respect to analytical solution + w = rand(length(fjac(xvec, yvec))) + if multiple_inputs + pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w) + valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w) + @test valvec == fjac(xvec, yvec) + @test 
norm.(pb2 .- pb1) == (0, 0) + @test minimum(isapprox.(pb1, (vJxp(xvec,yvec,w), vJyp(xvec,yvec,w)), atol=1e-10)) + @test xvec == xvec2 + @test yvec == yvec2 + end + + valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w) + valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w) + @test valvec1 == fjac(xvec, yvec) + @test valvec2 == fjac(xvec, yvec) + @test minimum(isapprox.((pb1[1],pb2[1]), (vJxp(xvec,yvec,w), vJyp(xvec,yvec,w)), atol=1e-10)) +end + +function test_lazy_derivatives(backend; multiple_inputs=true) + # single input function + der1 = AD.derivative(backend, x->fder(x, yscalar), xscalar) + lazyder = AD.LazyDerivative(backend, x->fder(x, yscalar), xscalar) + + # multiplication with scalar + @test lazyder*yscalar == der1.*yscalar + @test lazyder*yscalar isa Tuple + + @test yscalar*lazyder == yscalar.*der1 + @test yscalar*lazyder isa Tuple + + # multiplication with array + @test lazyder*yvec == (der1.*yvec,) + @test lazyder*yvec isa Tuple + + @test yvec*lazyder == (yvec.*der1,) + @test yvec*lazyder isa Tuple + + # multiplication with tuple + @test lazyder*(yscalar,) == lazyder*yscalar + @test lazyder*(yvec,) == lazyder*yvec + + @test (yscalar,)*lazyder == yscalar*lazyder + @test (yvec,)*lazyder == yvec*lazyder + + # two input function + if multiple_inputs + der1 = AD.derivative(backend, fder, xscalar, yscalar) + lazyder = AD.LazyDerivative(backend, fder, (xscalar, yscalar)) + + # multiplication with scalar + @test lazyder*yscalar == der1.*yscalar + @test lazyder*yscalar isa Tuple + + @test yscalar*lazyder == yscalar.*der1 + @test yscalar*lazyder isa Tuple + + # multiplication with array + @test lazyder*yvec == (der1[1]*yvec, der1[2]*yvec) + @test lazyder*yvec isa Tuple + + @test yvec*lazyder == (yvec*der1[1], yvec*der1[2]) + @test lazyder*yvec isa Tuple + + # multiplication with tuple + @test_throws AssertionError lazyder*(yscalar,) + @test_throws AssertionError lazyder*(yvec,) + + @test_throws AssertionError (yscalar,)*lazyder + @test_throws AssertionError (yvec,)*lazyder + end +end + +function test_lazy_gradients(backend; multiple_inputs=true) + # single input function + grad1 = AD.gradient(backend, x->fgrad(x, yvec), xvec) + lazygrad = AD.LazyGradient(backend, x->fgrad(x, yvec), xvec) + + # multiplication with scalar + @test norm.(lazygrad*yscalar .- grad1.*yscalar) == (0,) + @test lazygrad*yscalar isa Tuple + + @test norm.(yscalar*lazygrad .- yscalar.*grad1) == (0,) + @test yscalar*lazygrad isa Tuple + + # multiplication with tuple + @test lazygrad*(yscalar,) == lazygrad*yscalar + @test (yscalar,)*lazygrad == yscalar*lazygrad + + # two input function + if multiple_inputs + grad1 = AD.gradient(backend, fgrad, xvec, yvec) + lazygrad = AD.LazyGradient(backend, fgrad, (xvec, yvec)) + + # multiplication with scalar + @test norm.(lazygrad*yscalar .- grad1.*yscalar) == (0,0) + @test lazygrad*yscalar isa Tuple + + @test norm.(yscalar*lazygrad .- yscalar.*grad1) == (0,0) + @test yscalar*lazygrad isa Tuple + + # multiplication with tuple + @test_throws AssertionError lazygrad*(yscalar,) == lazygrad*yscalar + @test_throws AssertionError (yscalar,)*lazygrad == yscalar*lazygrad + end +end + +function test_lazy_jacobians(backend; multiple_inputs=true) + # single input function + jac1 = AD.jacobian(backend, x->fjac(x, yvec), xvec) + lazyjac = AD.LazyJacobian(backend, x->fjac(x, yvec), xvec) + + # multiplication with scalar + @test norm.(lazyjac*yscalar .- jac1.*yscalar) == (0,) + @test lazyjac*yscalar isa Tuple + + @test 
norm.(yscalar*lazyjac .- yscalar.*jac1) == (0,)
+    @test yscalar*lazyjac isa Tuple
+
+    w = rand(length(fjac(xvec, yvec)))
+    v = (rand(length(xvec)),rand(length(xvec)))
+
+    # vjp
+    pb1 = (vJxp(xvec,yvec,w),)
+    res = w'*lazyjac
+    @test minimum(isapprox.(pb1, res, atol=1e-10))
+    @test res isa Tuple
+
+    # jvp
+    pf1 = (jxvp(xvec,yvec,v[1]),)
+    res = lazyjac*v[1]
+    @test minimum(isapprox.(pf1, res, atol=1e-10))
+    @test res isa Tuple
+
+    # two input function
+    if multiple_inputs
+        jac1 = AD.jacobian(backend, fjac, xvec, yvec)
+        lazyjac = AD.LazyJacobian(backend, fjac, (xvec, yvec))
+
+        # multiplication with scalar
+        @test norm.(lazyjac*yscalar .- jac1.*yscalar) == (0,0)
+        @test lazyjac*yscalar isa Tuple
+
+        @test norm.(yscalar*lazyjac .- yscalar.*jac1) == (0,0)
+        @test yscalar*lazyjac isa Tuple
+
+        # vjp
+        pb1 = (vJxp(xvec,yvec,w), vJyp(xvec,yvec,w))
+        res = w'lazyjac
+        @test minimum(isapprox.(pb1, res, atol=1e-10))
+        @test res isa Tuple
+
+        # jvp
+        pf1 = (jxvp(xvec,yvec,v[1]), jyvp(xvec,yvec,v[2]))
+
+        if backend isa Union{FDMBackend2,ForwardDiffBackend2} # augmented version of v
+            identity_like = AD.identity_matrix_like(v)
+            vaug = map(identity_like) do identity_like_i
+                identity_like_i .* v
+            end
+
+            res = map(v->(lazyjac*v)[1], vaug)
+        else
+            res = lazyjac*v
+        end
+        @test minimum(isapprox.(pf1, res, atol=1e-10))
+        @test res isa Tuple
+    end
+end
+
+function test_lazy_hessians(backend; multiple_inputs=true)
+    # fdm_backend not used here yet.
+    # single input function
+    fhess = x -> fgrad(x, yvec)
+    hess1 = (dfgraddxdx(xvec,yvec),)
+    lazyhess = AD.LazyHessian(backend, fhess, xvec)
+
+    # multiplication with scalar
+    @test minimum(isapprox.(lazyhess*yscalar, hess1.*yscalar, atol=1e-10))
+    @test lazyhess*yscalar isa Tuple
+
+    # multiplication with scalar
+    @test minimum(isapprox.(yscalar*lazyhess, yscalar.*hess1, atol=1e-10))
+    @test yscalar*lazyhess isa Tuple
+
+    w = rand(length(xvec))
+    v = rand(length(xvec))
+
+    # Hvp
+    Hv = map(h->h*v, hess1)
+    res = lazyhess*v
+    @test minimum(isapprox.(Hv, res, atol=1e-10))
+    @test res isa Tuple
+
+    # H′vp
+    wH = map(h->h'*w, hess1)
+    res = w'*lazyhess
+    @test minimum(isapprox.(wH, res, atol=1e-10))
+    @test res isa Tuple
+end
 
 @testset "AbstractDifferentiation.jl" begin
-    # Write your tests here.
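+    # The same generic test functions above exercise all three primitive
+    # styles: Jacobian-based (FDMBackend1, ForwardDiffBackend1),
+    # pushforward-based (FDMBackend2, ForwardDiffBackend2), and
+    # pullback-based (FDMBackend3, ZygoteBackend1).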
+ @testset "Utils" begin + test_higher_order_backend(fdm_backend1, fdm_backend2, fdm_backend3, zygote_backend1, forwarddiff_backend2) + end + @testset "FiniteDifferences" begin + @testset "Derivative" begin + test_derivatives(fdm_backend1) + test_derivatives(fdm_backend2) + test_derivatives(fdm_backend3) + end + @testset "Gradient" begin + test_gradients(fdm_backend1) + test_gradients(fdm_backend2) + test_gradients(fdm_backend3) + end + @testset "Jacobian" begin + test_jacobians(fdm_backend1) + test_jacobians(fdm_backend2) + test_jacobians(fdm_backend3) + end + @testset "Hessian" begin + test_hessians(fdm_backend1) + test_hessians(fdm_backend2) + test_hessians(fdm_backend3) + end + @testset "jvp" begin + test_jvp(fdm_backend1) + test_jvp(fdm_backend2) + test_jvp(fdm_backend3) + end + @testset "j′vp" begin + test_j′vp(fdm_backend1) + test_j′vp(fdm_backend2) + test_j′vp(fdm_backend3) + end + @testset "Lazy Derivative" begin + test_lazy_derivatives(fdm_backend1) + test_lazy_derivatives(fdm_backend2) + test_lazy_derivatives(fdm_backend3) + end + @testset "Lazy Gradient" begin + test_lazy_gradients(fdm_backend1) + test_lazy_gradients(fdm_backend2) + test_lazy_gradients(fdm_backend3) + end + @testset "Lazy Jacobian" begin + test_lazy_jacobians(fdm_backend1) + test_lazy_jacobians(fdm_backend2) + test_lazy_jacobians(fdm_backend3) + end + @testset "Lazy Hessian" begin + test_lazy_hessians(fdm_backend1) + test_lazy_hessians(fdm_backend2) + test_lazy_hessians(fdm_backend3) + end + end + @testset "ForwardDiff" begin + @testset "Derivative" begin + test_derivatives(forwarddiff_backend1; multiple_inputs=false) + test_derivatives(forwarddiff_backend2) + end + @testset "Gradient" begin + test_gradients(forwarddiff_backend1; multiple_inputs=false) + test_gradients(forwarddiff_backend2) + end + @testset "Jacobian" begin + test_jacobians(forwarddiff_backend1; multiple_inputs=false) + test_jacobians(forwarddiff_backend2) + end + @testset "Hessian" begin + test_hessians(forwarddiff_backend1; multiple_inputs=false) + test_hessians(forwarddiff_backend2) + end + @testset "jvp" begin + test_jvp(forwarddiff_backend1; multiple_inputs=false) + test_jvp(forwarddiff_backend2) + end + @testset "j′vp" begin + test_j′vp(forwarddiff_backend1; multiple_inputs=false) + test_j′vp(forwarddiff_backend2) + end + @testset "Lazy Derivative" begin + test_lazy_derivatives(forwarddiff_backend1; multiple_inputs=false) + test_lazy_derivatives(forwarddiff_backend2) + end + @testset "Lazy Gradient" begin + test_lazy_gradients(forwarddiff_backend1; multiple_inputs=false) + test_lazy_gradients(forwarddiff_backend2) + end + @testset "Lazy Jacobian" begin + test_lazy_jacobians(forwarddiff_backend1; multiple_inputs=false) + test_lazy_jacobians(forwarddiff_backend2) + end + @testset "Lazy Hessian" begin + test_lazy_hessians(forwarddiff_backend1; multiple_inputs=false) + test_lazy_hessians(forwarddiff_backend2) + end + end + @testset "Zygote" begin + @testset "Derivative" begin + test_derivatives(zygote_backend1) + end + @testset "Gradient" begin + test_gradients(zygote_backend1) + end + @testset "Jacobian" begin + test_jacobians(zygote_backend1) + end + @testset "Hessian" begin + # Zygote over Zygote problems + backends = AD.HigherOrderBackend((forwarddiff_backend2,zygote_backend1)) + test_hessians(backends) + backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1)) + test_hessians(backends) + # fails: + # backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend2)) + # test_hessians(backends) + end + @testset "jvp" 
begin
+            test_jvp(zygote_backend1)
+        end
+        @testset "j′vp" begin
+            test_j′vp(zygote_backend1)
+        end
+        @testset "Lazy Derivative" begin
+            test_lazy_derivatives(zygote_backend1)
+        end
+        @testset "Lazy Gradient" begin
+            test_lazy_gradients(zygote_backend1)
+        end
+        @testset "Lazy Jacobian" begin
+            test_lazy_jacobians(zygote_backend1)
+        end
+        @testset "Lazy Hessian" begin
+            # Zygote over Zygote problems
+            backends = AD.HigherOrderBackend((forwarddiff_backend2,zygote_backend1))
+            test_lazy_hessians(backends)
+            backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1))
+            test_lazy_hessians(backends)
+        end
+    end
 end
+
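
A minimal end-to-end sketch of the API this patch introduces, assuming the
`FDMBackend1` backend defined in test/runtests.jl above is in scope; every
other name comes from this diff:

    using AbstractDifferentiation
    backend = FDMBackend1()            # Jacobian-primitive backend
    f(x) = sum(abs2, x)
    x = rand(3)

    AD.derivative(backend, sin, 0.5)   # ≈ (cos(0.5),)
    val, (grad,) = AD.value_and_gradient(backend, f, x)
    # val == f(x); grad ≈ 2x up to finite-difference error
    AD.hessian(backend, f, x)          # ≈ (2I,), returned as a 1-tuple

Defining the single `jacobian` primitive via `AD.@primitive` is enough for
the macro to derive derivatives, gradients, Hessians, and the value-and-*
variants shown here; pushforward- and pullback-primitive backends get the
same surface through their respective `@primitive` definitions.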