Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring mutating optimisations back #13

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# This file is machine-generated - editing it directly is not advised

[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays"]
git-tree-sha1 = "ee07ae00e3cc277dcfa5507ce25be522313ecc3e"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.1"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

Expand All @@ -11,6 +17,11 @@ repo-url = "https://github.com/FluxML/Functors.jl"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.1.0"

[[IfElse]]
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.0"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.1.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -12,6 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
Functors = "0"
Requires = "0.5, 1"
ArrayInterface = "3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
2 changes: 2 additions & 0 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module Optimisers

using Functors: functor, fmap, isleaf
using ArrayInterface

include("interface.jl")
include("rules.jl")
include("mutating.jl")

export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
Expand Down
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function state(o, x)
end

# Apply rule `o` to `x`, `x̄` and rule state `st`, then `patch` `x` with the result.
#
# The rule is run exactly once: through the in-place `apply!` when `ismutable(x)`
# is true, and through `apply` otherwise. Running `apply` unconditionally first
# and then dispatching again would feed the already-transformed `x̄` and `st`
# through the rule a second time.
#
# Returns the tuple `(patch(x, x̄), st)` built from the rule's outputs.
function _update(o, x, x̄, st)
  x̄, st = ismutable(x) ? apply!(o, x, x̄, st) : apply(o, x, x̄, st)
  return patch(x, x̄), st
end

Expand Down
160 changes: 160 additions & 0 deletions src/mutating.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# This file duplicates the rules defined in rules.jl, with all of the
# operations done in-place, to allow for a softer deprecation.

# In-place `Descent` rule: multiplies `dx` elementwise by `o.eta` and returns
# that same `dx` together with the unchanged `state`.
function apply!(o::Descent, x, dx, state)
  # Convert the step size to the element type of `dx` before broadcasting.
  T = eltype(dx)
  step = convert(T, o.eta)
  @. dx = dx * step

  return (dx, state)
end

# In-place `Momentum` rule. `state` is overwritten elementwise with
# `o.rho * state - o.eta * dx`; the function returns the negated state
# together with the updated state itself.
function apply!(o::Momentum, x, dx, state)
  velocity = state
  lr = o.eta
  decay = o.rho

  velocity .= decay .* velocity .- lr .* dx

  return (-velocity, velocity)
end

# In-place `Nesterov` rule.
#
# Named `apply!` like the other in-place rules in this file, so that `_update`
# can reach it for mutable `x`. Under the name `apply` this method is never
# called through the in-place branch.
#
# `state` (bound to `v`) is overwritten with `ρ * v - η * dx`. The returned step
# is the negation of `ρ^2 * v - (1 + ρ) * η * dx`, evaluated with the *old* `v`.
function apply!(o::Nesterov, x, dx, state)
  η, ρ, v = o.eta, o.rho, state

  # `d` must be a fresh array: `@. d = ...` would try to write into an undefined
  # variable. It has to be computed before `v` is overwritten below, and `dx`
  # cannot be reused as the buffer because the `v` update still reads it.
  d = @. ρ^2 * v - (1 + ρ) * η * dx
  @. v = ρ * v - η * dx

  return -d, v
end

# In-place `RMSProp` rule.
#
# Named `apply!` like the other in-place rules in this file, so that `_update`
# can reach it for mutable `x`; it mutates both `state` and `dx`, which the
# non-`!` name would not signal.
#
# `state` (bound to `acc`) is overwritten with `ρ * acc + (1 - ρ) * dx^2`, then
# `dx` is rescaled elementwise by `η / (sqrt(acc) + ϵ)`.
# Returns the mutated `dx` and `acc`.
function apply!(o::RMSProp, x, dx, state)
  η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state

  @. acc = ρ * acc + (1 - ρ) * dx^2
  @. dx = dx * (η / (sqrt(acc) + ϵ))

  return dx, acc
end

# In-place `ADAM` rule.
#
# Named `apply!` like the other in-place rules in this file, so that `_update`
# can reach it for mutable `x`; it mutates `mt`, `vt` and `dx`.
#
# `state` is destructured as `(mt, vt, βt)`:
#   * `mt` becomes `β[1] * mt + (1 - β[1]) * dx`,
#   * `vt` becomes `β[2] * vt + (1 - β[2]) * dx^2`,
#   * `βt` is only read here; the returned state carries `βt .* β`.
# `dx` is overwritten with the step built from `mt`, `vt`, `βt`, `ϵ` and `η`.
function apply!(o::ADAM{T}, x, dx, state) where T
  η, β, ϵ = o.eta, o.beta, o.epsilon
  mt, vt, βt = state

  # `one(T)` is used in place of the literal `1` in every `1 - β` term.
  @. mt = β[1] * mt + (one(T) - β[1]) * dx
  @. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
  @. dx = mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η

  return dx, (mt, vt, βt .* β)
end

# In-place `RADAM` rule.
#
# Named `apply!` like the other in-place rules in this file, so that `_update`
# can reach it for mutable `x`; it mutates `mt`, `vt` and `dx`.
#
# `state` is destructured as `(mt, vt, βt, t)`. `mt` and `vt` are updated in
# place from `dx` and `dx^2`; the returned state carries `βt .* β` and `t + 1`.
function apply!(o::RADAM, x, dx, state)
  η, β, ϵ = o.eta, o.beta, o.epsilon
  ρ∞ = 2/(1-β[2])-1

  mt, vt, βt, t = state

  @. mt = β[1] * mt + (1 - β[1]) * dx
  @. vt = β[2] * vt + (1 - β[2]) * dx^2
  ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
  if ρ > 4
    # Step uses both `mt` and `vt`, scaled by the factor `r`.
    r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
    @. dx = mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
  else
    # For `ρ <= 4` the step uses only `mt`; `vt` does not enter it.
    @. dx = mt / (1 - βt[1]) * η
  end

  return dx, (mt, vt, βt .* β, t + 1)
end

# In-place `AdaMax` rule. The method is defined once; a second, identical
# definition would only overwrite this one.
#
# `state` is destructured as `((mt, ut), βt)`:
#   * `mt` becomes `β[1] * mt + (1 - β[1]) * dx`,
#   * `ut` becomes the elementwise `max(β[2] * ut, abs(dx))`,
#   * `βt` is only read here; the returned state carries `βt .* β`.
# `dx` is overwritten with `(η / (1 - βt[1])) * mt / (ut + ϵ)` and returned.
function apply!(o::AdaMax, x, dx, state)
  η, β, ϵ = o.eta, o.beta, o.epsilon

  (mt, ut), βt = state

  @. mt = β[1] * mt + (1 - β[1]) * dx
  @. ut = max(β[2] * ut, abs(dx))
  @. dx = (η/(1 - βt[1])) * mt/(ut + ϵ)

  return dx, ((mt, ut), βt .* β)
end

# In-place `OADAM` rule.
#
# `state` is destructured as `((mt, vt), βt, prev)`. `mt` and `vt` are updated
# in place from `dx` and `dx^2`. `dx` is first set to `-prev`, then `prev` is
# overwritten with the newly computed step, and finally twice the new `prev`
# is added to `dx`. The returned state carries `βt .* o.beta`.
function apply!(o::OADAM, x, dx, state)
  (mt, vt), βt, prev = state
  η, ϵ = o.eta, o.epsilon
  β1, β2 = o.beta[1], o.beta[2]

  mt .= β1 .* mt .+ (1 - β1) .* dx
  vt .= β2 .* vt .+ (1 - β2) .* dx .^ 2

  # Order matters: `dx` must read the old `prev` before `prev` is overwritten.
  dx .= .-prev
  prev .= η .* mt ./ (1 - βt[1]) ./ (sqrt.(vt ./ (1 - βt[2])) .+ ϵ)
  dx .+= 2 .* prev

  return dx, ((mt, vt), βt .* o.beta, prev)
end

# In-place `ADAGrad` rule.
#
# The first element of `state` is the accumulator: `dx^2` is added to it in
# place, then `dx` is multiplied elementwise by `η / (sqrt(acc) + ϵ)`.
# Returns the mutated `dx` and the accumulator wrapped in a one-element tuple.
function apply!(o::ADAGrad, x, dx, state)
  acc = first(state)
  lr = o.eta
  eps_ = o.epsilon

  acc .= acc .+ dx .^ 2
  dx .= dx .* (lr ./ (sqrt.(acc) .+ eps_))

  return (dx, (acc,))
end

# In-place `ADADelta` rule.
#
# `state` is destructured as `(acc, Δacc)`. `acc` is updated from `dx^2`, `dx`
# is rescaled by `sqrt(Δacc + ϵ) / sqrt(acc + ϵ)`, and `Δacc` is then updated
# from the square of the *rescaled* `dx`.
function apply!(o::ADADelta, x, dx, state)
  acc, Δacc = state
  decay = o.rho
  eps_ = o.epsilon

  acc .= decay .* acc .+ (1 - decay) .* dx .^ 2
  # Keep epsilon inside both square roots, numerator included:
  # DON'T remove it from the numerator or move it out of the roots.
  dx .= dx .* (sqrt.(Δacc .+ eps_) ./ sqrt.(acc .+ eps_))
  Δacc .= decay .* Δacc .+ (1 - decay) .* dx .^ 2

  return (dx, (acc, Δacc))
end

# In-place `AMSGrad` rule.
#
# `state` is destructured as `(mt, vt, v̂t)`. `mt` and `vt` are updated in place
# from `dx` and `dx^2`, `v̂t` keeps the elementwise maximum of itself and `vt`,
# and `dx` is overwritten with `η * mt / (sqrt(v̂t) + ϵ)`.
function apply!(o::AMSGrad, x, dx, state)
  mt, vt, v̂t = state
  lr, eps_ = o.eta, o.epsilon
  β1, β2 = o.beta[1], o.beta[2]

  mt .= β1 .* mt .+ (1 - β1) .* dx
  vt .= β2 .* vt .+ (1 - β2) .* dx .^ 2
  v̂t .= max.(v̂t, vt)
  dx .= lr .* mt ./ (sqrt.(v̂t) .+ eps_)

  return (dx, (mt, vt, v̂t))
end

# In-place `NADAM` rule.
#
# `state` is destructured directly as `((mt, vt), βt)`, which is exactly the
# shape this method returns. Reading `state.decays` instead fails on that
# tuple, since a tuple has no such field.
# NOTE(review): the initial state for `NADAM` is built outside this block —
# confirm it also has the `((mt, vt), βt)` layout.
#
# `mt` and `vt` are updated in place from `dx` and `dx^2`; `dx` is overwritten
# with the step; the returned state carries `βt .* β`.
function apply!(o::NADAM, x, dx, state)
  η, β, ϵ = o.eta, o.beta, o.epsilon

  (mt, vt), βt = state

  @. mt = β[1] * mt + (1 - β[1]) * dx
  @. vt = β[2] * vt + (1 - β[2]) * dx^2
  # The right-hand side reads the incoming `dx` elementwise before writing it.
  @. dx = (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
          (sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η

  return dx, ((mt, vt), βt .* β)
end

# In-place `AdaBelief` rule.
#
# `state` is destructured as `(mt, st)`. `mt` is updated in place from `dx`,
# `st` from the squared difference between `dx` and the *updated* `mt`, and
# `dx` is overwritten with `η * mt / (sqrt(st) + ϵ)`.
function apply!(o::AdaBelief, x, dx, state)
  mt, st = state
  lr, eps_ = o.eta, o.epsilon
  β1, β2 = o.beta[1], o.beta[2]

  mt .= β1 .* mt .+ (1 - β1) .* dx
  st .= β2 .* st .+ (1 - β2) .* (dx .- mt) .^ 2
  dx .= lr .* mt ./ (sqrt.(st) .+ eps_)

  return (dx, (mt, st))
end
3 changes: 1 addition & 2 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ init(o::Descent, x::AbstractArray) = nothing

# Non-mutating `Descent` rule: returns `dx .* η` as a new value, together with
# the unchanged `state`.
#
# `dx` itself is left untouched. Scaling it in place before the return would
# both mutate the caller's array and apply `η` twice to the returned value.
function apply(o::Descent, x, dx, state)
  # Convert the step size to the element type of `dx` before broadcasting.
  η = convert(eltype(dx), o.eta)

  return dx .* η, state
end
Expand Down Expand Up @@ -513,4 +512,4 @@ function apply(o::OptimiserChain, x, dx, states)
end

return dx, new_states
end
end