Skip to content

Commit

Permalink
Make grad_J_a act not-in-place
Browse files Browse the repository at this point in the history
This is much easier for a user to deal with. It really shouldn't make
any difference for performance, and in fact, may enhance it (the old
implementation was using `copyto!`). In situations where it would make a
difference, like for insanely large pulses, there's always the option of
using a non-allocating functor.
  • Loading branch information
goerz committed Sep 3, 2024
1 parent c85edec commit 14068fe
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 33 deletions.
6 changes: 3 additions & 3 deletions ext/QuantumControlFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ end


function make_automatic_grad_J_a(J_a, tlist, ::Val{:FiniteDifferences})
function automatic_grad_J_a!(∇J_a, pulsevals, tlist)
function automatic_grad_J_a(pulsevals, tlist)
func = pulsevals -> J_a(pulsevals, tlist)
fdm = FiniteDifferences.central_fdm(5, 1)
∇J_a_fdm = FiniteDifferences.grad(fdm, func, pulsevals)[1]
copyto!(∇J_a, ∇J_a_fdm)
return ∇J_a_fdm
end
return automatic_grad_J_a!
return automatic_grad_J_a
end

function make_gate_chi(J_T_U, trajectories, ::Val{:FiniteDifferences}; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions ext/QuantumControlZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ end


function make_automatic_grad_J_a(J_a, tlist, ::Val{:Zygote})
function automatic_grad_J_a!(∇J_a, pulsevals, tlist)
function automatic_grad_J_a(pulsevals, tlist)
func = pulsevals -> J_a(pulsevals, tlist)
∇J_a_zygote = Zygote.gradient(func, pulsevals)[1]
copyto!(∇J_a, ∇J_a_zygote)
return ∇J_a_zygote
end
return automatic_grad_J_a!
return automatic_grad_J_a
end


Expand Down
30 changes: 16 additions & 14 deletions src/functionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,18 +317,18 @@ end
Return a function to evaluate ``∂J_a/∂ϵ_{ln}`` for a pulse value running cost.
```julia
grad_J_a! = make_grad_J_a(
grad_J_a = make_grad_J_a(
J_a,
tlist;
mode=:any,
automatic=:default,
)
```
returns a function so that `grad_J_a!(∇J_a, pulsevals, tlist)` sets
``∂J_a/∂ϵ_{ln}`` as the elements of the (vectorized) `∇J_a`. The function `J_a`
must have the interface `J_a(pulsevals, tlist)`, see, e.g.,
`J_a_fluence`.
returns a function so that `∇J_a = grad_J_a(pulsevals, tlist)` sets
that retrurns a vector `∇J_a` containing the vectorized elements
``∂J_a/∂ϵ_{ln}``. The function `J_a` must have the interface `J_a(pulsevals,
tlist)`, see, e.g., [`J_a_fluence`](@ref).
The parameters `mode` and `automatic` are handled as in [`make_chi`](@ref),
where `mode` is one of `:any`, `:analytic`, `:automatic`, and `automatic` is
Expand All @@ -341,10 +341,11 @@ refers to the framework set with `QuantumControl.set_default_ad_framework`.
new `J_a` function, define a new method `make_analytic_grad_J_a` like so:
```julia
make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence!
make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence
```
which links `make_grad_J_a` for `J_a_fluence` to `grad_J_a_fluence!`.
which links `make_grad_J_a` for [`J_a_fluence`](@ref) to
[`grad_J_a_fluence`](@ref).
"""
function make_grad_J_a(J_a, tlist; mode=:any, automatic=:default)
if mode == :any
Expand Down Expand Up @@ -890,19 +891,20 @@ end
"""Analytic derivative for [`J_a_fluence`](@ref).
```julia
grad_J_a_fluence!(∇J_a, pulsevals, tlist)
∇J_a = grad_J_a_fluence(pulsevals, tlist)
```
sets the (vectorized) elements of `∇J_a` to ``2 ϵ_{nl} dt``, where
``ϵ_{nl}`` are the (vectorized) elements of `pulsevals` and ``dt`` is the time
step, taken from the first time interval of `tlist` and assumed to be uniform.
returns the `∇J_a`, which contains the (vectorized) elements ``2 ϵ_{nl} dt``,
where ``ϵ_{nl}`` are the (vectorized) elements of `pulsevals` and ``dt`` is the
time step, taken from the first time interval of `tlist` and assumed to be
uniform.
"""
function grad_J_a_fluence!(∇J_a, pulsevals, tlist)
function grad_J_a_fluence(pulsevals, tlist)
dt = tlist[begin+1] - tlist[begin]
axpy!(2 * dt, pulsevals, ∇J_a)
return (2 * dt) * pulsevals
end


make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence!
make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence

end
23 changes: 10 additions & 13 deletions test/test_functionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using QuantumControl.Functionals:
J_T_re,
J_T_ss,
J_a_fluence,
grad_J_a_fluence!,
grad_J_a_fluence,
make_grad_J_a,
make_chi,
chi_re,
Expand Down Expand Up @@ -101,20 +101,17 @@ end
J_a_val = J_a_fluence(pulsevals, tlist)
@test J_a_val > 0.0

G1 = copy(wrk.grad_J_a)
grad_J_a_fluence!(G1, pulsevals, tlist)
G1 = grad_J_a_fluence(pulsevals, tlist)

grad_J_a_zygote! = make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=Zygote)
@test grad_J_a_zygote! grad_J_a_fluence!
G2 = copy(wrk.grad_J_a)
grad_J_a_zygote!(G2, pulsevals, tlist)
grad_J_a_zygote = make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=Zygote)
@test grad_J_a_zygote grad_J_a_fluence
G2 = grad_J_a_zygote(pulsevals, tlist)

grad_J_a_fdm! =
grad_J_a_fdm =
make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=FiniteDifferences)
@test grad_J_a_fdm! grad_J_a_fluence!
@test grad_J_a_fdm! grad_J_a_zygote!
G3 = copy(wrk.grad_J_a)
grad_J_a_fdm!(G3, pulsevals, tlist)
@test grad_J_a_fdm grad_J_a_fluence
@test grad_J_a_fdm grad_J_a_zygote
G3 = grad_J_a_fdm(pulsevals, tlist)

@test 0.0 norm(G2 - G1) < 1e-12 # zygote can be exact
@test 0.0 < norm(G3 - G1) < 1e-12 # fdm should not be exact
Expand Down Expand Up @@ -324,7 +321,7 @@ end
end
grad_J_a = capture.value
@test_throws DomainError begin
grad_J_a(1, 1, tlist)
grad_J_a(1, tlist)
end

QuantumControl.set_default_ad_framework(nothing; quiet=true)
Expand Down

0 comments on commit 14068fe

Please # to comment.