From 14068fe7e0f44624bd9562418ba205f4b6c46f50 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Tue, 3 Sep 2024 09:36:37 -0400 Subject: [PATCH] Make grad_J_a act not-in-place This is much easier for a user to deal with. It really shouldn't make any difference for performance, and in fact, may enhance it (the old implementation was using `copyto!`). In situations where it would make a difference, like for insanely large pulses, there's always the option of using a non-allocating functor. --- ext/QuantumControlFiniteDifferencesExt.jl | 6 ++--- ext/QuantumControlZygoteExt.jl | 6 ++--- src/functionals.jl | 30 ++++++++++++----------- test/test_functionals.jl | 23 ++++++++--------- 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/ext/QuantumControlFiniteDifferencesExt.jl b/ext/QuantumControlFiniteDifferencesExt.jl index 248412b..3c97e69 100644 --- a/ext/QuantumControlFiniteDifferencesExt.jl +++ b/ext/QuantumControlFiniteDifferencesExt.jl @@ -72,13 +72,13 @@ end function make_automatic_grad_J_a(J_a, tlist, ::Val{:FiniteDifferences}) - function automatic_grad_J_a!(∇J_a, pulsevals, tlist) + function automatic_grad_J_a(pulsevals, tlist) func = pulsevals -> J_a(pulsevals, tlist) fdm = FiniteDifferences.central_fdm(5, 1) ∇J_a_fdm = FiniteDifferences.grad(fdm, func, pulsevals)[1] - copyto!(∇J_a, ∇J_a_fdm) + return ∇J_a_fdm end - return automatic_grad_J_a! + return automatic_grad_J_a end function make_gate_chi(J_T_U, trajectories, ::Val{:FiniteDifferences}; kwargs...) diff --git a/ext/QuantumControlZygoteExt.jl b/ext/QuantumControlZygoteExt.jl index 8cf24de..858d28a 100644 --- a/ext/QuantumControlZygoteExt.jl +++ b/ext/QuantumControlZygoteExt.jl @@ -73,12 +73,12 @@ end function make_automatic_grad_J_a(J_a, tlist, ::Val{:Zygote}) - function automatic_grad_J_a!(∇J_a, pulsevals, tlist) + function automatic_grad_J_a(pulsevals, tlist) func = pulsevals -> J_a(pulsevals, tlist) ∇J_a_zygote = Zygote.gradient(func, pulsevals)[1] - copyto!(∇J_a, ∇J_a_zygote) + return ∇J_a_zygote end - return automatic_grad_J_a! + return automatic_grad_J_a end diff --git a/src/functionals.jl b/src/functionals.jl index 0850391..4912d31 100644 --- a/src/functionals.jl +++ b/src/functionals.jl @@ -317,7 +317,7 @@ end Return a function to evaluate ``∂J_a/∂ϵ_{ln}`` for a pulse value running cost. ```julia -grad_J_a! = make_grad_J_a( +grad_J_a = make_grad_J_a( J_a, tlist; mode=:any, @@ -325,10 +325,10 @@ grad_J_a! = make_grad_J_a( ) ``` -returns a function so that `grad_J_a!(∇J_a, pulsevals, tlist)` sets -``∂J_a/∂ϵ_{ln}`` as the elements of the (vectorized) `∇J_a`. The function `J_a` -must have the interface `J_a(pulsevals, tlist)`, see, e.g., -`J_a_fluence`. +returns a function so that `∇J_a = grad_J_a(pulsevals, tlist)` sets +that retrurns a vector `∇J_a` containing the vectorized elements +``∂J_a/∂ϵ_{ln}``. The function `J_a` must have the interface `J_a(pulsevals, +tlist)`, see, e.g., [`J_a_fluence`](@ref). The parameters `mode` and `automatic` are handled as in [`make_chi`](@ref), where `mode` is one of `:any`, `:analytic`, `:automatic`, and `automatic` is @@ -341,10 +341,11 @@ refers to the framework set with `QuantumControl.set_default_ad_framework`. new `J_a` function, define a new method `make_analytic_grad_J_a` like so: ```julia - make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence! + make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence ``` - which links `make_grad_J_a` for `J_a_fluence` to `grad_J_a_fluence!`. + which links `make_grad_J_a` for [`J_a_fluence`](@ref) to + [`grad_J_a_fluence`](@ref). """ function make_grad_J_a(J_a, tlist; mode=:any, automatic=:default) if mode == :any @@ -890,19 +891,20 @@ end """Analytic derivative for [`J_a_fluence`](@ref). ```julia -grad_J_a_fluence!(∇J_a, pulsevals, tlist) +∇J_a = grad_J_a_fluence(pulsevals, tlist) ``` -sets the (vectorized) elements of `∇J_a` to ``2 ϵ_{nl} dt``, where -``ϵ_{nl}`` are the (vectorized) elements of `pulsevals` and ``dt`` is the time -step, taken from the first time interval of `tlist` and assumed to be uniform. +returns the `∇J_a`, which contains the (vectorized) elements ``2 ϵ_{nl} dt``, +where ``ϵ_{nl}`` are the (vectorized) elements of `pulsevals` and ``dt`` is the +time step, taken from the first time interval of `tlist` and assumed to be +uniform. """ -function grad_J_a_fluence!(∇J_a, pulsevals, tlist) +function grad_J_a_fluence(pulsevals, tlist) dt = tlist[begin+1] - tlist[begin] - axpy!(2 * dt, pulsevals, ∇J_a) + return (2 * dt) * pulsevals end -make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence! +make_analytic_grad_J_a(::typeof(J_a_fluence), tlist) = grad_J_a_fluence end diff --git a/test/test_functionals.jl b/test/test_functionals.jl index 9b06829..8675bed 100644 --- a/test/test_functionals.jl +++ b/test/test_functionals.jl @@ -6,7 +6,7 @@ using QuantumControl.Functionals: J_T_re, J_T_ss, J_a_fluence, - grad_J_a_fluence!, + grad_J_a_fluence, make_grad_J_a, make_chi, chi_re, @@ -101,20 +101,17 @@ end J_a_val = J_a_fluence(pulsevals, tlist) @test J_a_val > 0.0 - G1 = copy(wrk.grad_J_a) - grad_J_a_fluence!(G1, pulsevals, tlist) + G1 = grad_J_a_fluence(pulsevals, tlist) - grad_J_a_zygote! = make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=Zygote) - @test grad_J_a_zygote! ≢ grad_J_a_fluence! - G2 = copy(wrk.grad_J_a) - grad_J_a_zygote!(G2, pulsevals, tlist) + grad_J_a_zygote = make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=Zygote) + @test grad_J_a_zygote ≢ grad_J_a_fluence + G2 = grad_J_a_zygote(pulsevals, tlist) - grad_J_a_fdm! = + grad_J_a_fdm = make_grad_J_a(J_a_fluence, tlist; mode=:automatic, automatic=FiniteDifferences) - @test grad_J_a_fdm! ≢ grad_J_a_fluence! - @test grad_J_a_fdm! ≢ grad_J_a_zygote! - G3 = copy(wrk.grad_J_a) - grad_J_a_fdm!(G3, pulsevals, tlist) + @test grad_J_a_fdm ≢ grad_J_a_fluence + @test grad_J_a_fdm ≢ grad_J_a_zygote + G3 = grad_J_a_fdm(pulsevals, tlist) @test 0.0 ≤ norm(G2 - G1) < 1e-12 # zygote can be exact @test 0.0 < norm(G3 - G1) < 1e-12 # fdm should not be exact @@ -324,7 +321,7 @@ end end grad_J_a = capture.value @test_throws DomainError begin - grad_J_a(1, 1, tlist) + grad_J_a(1, tlist) end QuantumControl.set_default_ad_framework(nothing; quiet=true)