diff --git a/Manifest.toml b/Manifest.toml
index 23989dc9..2b8a9b09 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -1,5 +1,11 @@
 # This file is machine-generated - editing it directly is not advised
 
+[[ArrayInterface]]
+deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays"]
+git-tree-sha1 = "ee07ae00e3cc277dcfa5507ce25be522313ecc3e"
+uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
+version = "3.1.1"
+
 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 
@@ -11,6 +17,11 @@ repo-url = "https://github.com/FluxML/Functors.jl"
 uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 version = "0.1.0"
 
+[[IfElse]]
+git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
+uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
+version = "0.1.0"
+
 [[Libdl]]
 uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 
diff --git a/Project.toml b/Project.toml
index 62b70e6f..5b95afd1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,7 @@ authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
 version = "0.1.0"
 
 [deps]
+ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -12,6 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 Functors = "0"
 Requires = "0.5, 1"
+ArrayInterface = "3"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index c57aa5a7..61db7765 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -1,9 +1,11 @@
 module Optimisers
 
 using Functors: functor, fmap, isleaf
+using ArrayInterface
 
 include("interface.jl")
 include("rules.jl")
+include("mutating.jl")
 
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
diff --git a/src/interface.jl b/src/interface.jl
index b59dad53..669fca6a 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -10,7 +10,8 @@ function state(o, x)
 end
 
 function _update(o, x, x̄, st)
-  x̄, st = apply(o, x, x̄, st)
+  # The in-place rules overwrite their arguments, so only take that path when `x` is mutable.
+  x̄, st = ArrayInterface.ismutable(x) ? apply!(o, x, x̄, st) : apply(o, x, x̄, st)
   return patch(x, x̄), st
 end
 
diff --git a/src/mutating.jl b/src/mutating.jl
new file mode 100644
index 00000000..d8b5106b
--- /dev/null
+++ b/src/mutating.jl
@@ -0,0 +1,164 @@
+# This file duplicates the rules from rules.jl with every operation done
+# in-place, for a softer deprecation.
+#
+# `apply!(o, x, dx, state)` returns `(update, new_state)`. It may overwrite `dx`
+# and the arrays held in `state`, so callers must use the returned values:
+# Momentum and Nesterov return a freshly allocated update rather than `dx`.
+
+function apply!(o::Descent, x, dx, state)
+  η = convert(eltype(dx), o.eta)
+  dx .*= η
+
+  return dx, state
+end
+
+function apply!(o::Momentum, x, dx, state)
+  η, ρ, v = o.eta, o.rho, state
+  @. v = ρ * v - η * dx
+
+  return -v, v
+end
+
+function apply!(o::Nesterov, x, dx, state)
+  η, ρ, v = o.eta, o.rho, state
+  # `d` needs the velocity from before the update below, so it is allocated
+  # here; `@. d = ...` would broadcast into a variable that does not exist yet.
+  d = @. ρ^2 * v - (1+ρ) * η * dx
+  @. v = ρ * v - η * dx
+
+  return -d, v
+end
+
+function apply!(o::RMSProp, x, dx, state)
+  η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
+  @. acc = ρ * acc + (1 - ρ) * dx^2
+  @. dx = dx * (η / (sqrt(acc) + ϵ))
+
+  return dx, acc
+end
+
+function apply!(o::ADAM{T}, x, dx, state) where T
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+  mt, vt, βt = state
+
+  @. mt = β[1] * mt + (one(T) - β[1]) * dx
+  @. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
+  @. dx = mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η
+
+  return dx, (mt, vt, βt .* β)
+end
+
+function apply!(o::RADAM, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+  ρ∞ = 2/(1-β[2])-1
+
+  mt, vt, βt, t = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx^2
+  ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
+  if ρ > 4
+    r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
+    @. dx = mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
+  else
+    @. dx = mt / (1 - βt[1]) * η
+  end
+
+  return dx, (mt, vt, βt .* β, t + 1)
+end
+
+function apply!(o::AdaMax, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, ut, βt = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. ut = max(β[2] * ut, abs(dx))
+  @. dx = (η/(1 - βt[1])) * mt/(ut + ϵ)
+
+  return dx, (mt, ut, βt .* β)
+end
+
+function apply!(o::OADAM, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, vt, βt, dx_ = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx^2
+  @. dx = -dx_
+  @. dx_ = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
+  @. dx += 2*dx_
+
+  return dx, (mt, vt, βt .* β, dx_)
+end
+
+function apply!(o::ADAGrad, x, dx, state)
+  η, ϵ = o.eta, o.epsilon
+  acc, = state
+
+  @. acc += dx^2
+  @. dx *= η / (sqrt(acc) + ϵ)
+
+  return dx, (acc,)
+end
+
+function apply!(o::ADADelta, x, dx, state)
+  ρ, ϵ = o.rho, o.epsilon
+  acc, Δacc = state
+
+  @. acc = ρ * acc + (1 - ρ) * dx^2
+  # DON'T remove epsilon from numerator
+  # or even out of the square roots
+  @. dx *= sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
+  @. Δacc = ρ * Δacc + (1 - ρ) * dx^2
+
+  return dx, (acc, Δacc)
+end
+
+function apply!(o::AMSGrad, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, vt, v̂t = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
+  @. v̂t = max(v̂t, vt)
+  @. dx = η * mt / (sqrt(v̂t) + ϵ)
+
+  return dx, (mt, vt, v̂t)
+end
+
+function apply!(o::NADAM, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, vt, βt = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx^2
+  @. dx = (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
+          (sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η
+
+  return dx, (mt, vt, βt .* β)
+end
+
+function apply!(o::AdaBelief, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+  mt, st = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
+  @. dx = η * mt / (sqrt(st) + ϵ)
+
+  return dx, (mt, st)
+end
+
+function apply!(o::OptimiserChain, x, dx, states)
+  new_states = similar(states)
+  for (i, (opt, state)) in enumerate(zip(o.opts, states))
+    # Thread the returned update through: not every rule returns `dx` itself.
+    dx, new_states[i] = apply!(opt, x, dx, state)
+  end
+
+  return dx, new_states
+end
diff --git a/src/rules.jl b/src/rules.jl
index 9c89ea28..d52c7596 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -17,7 +17,6 @@ init(o::Descent, x::AbstractArray) = nothing
 
 function apply(o::Descent, x, dx, state)
   η = convert(eltype(dx), o.eta)
-  dx .*= η
 
   return dx .* η, state
 end
@@ -513,4 +512,4 @@ function apply(o::OptimiserChain, x, dx, states)
   end
 
   return dx, new_states
-end
\ No newline at end of file
+end