diff --git a/Manifest.toml b/Manifest.toml index b5496ecad7..6bf05007af 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -14,9 +14,9 @@ version = "0.3.3" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "27edd95a09fd428113ca019c092e8aeca2eb1f2d" +git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.0.0" +version = "3.2.0" [[Artifacts]] deps = ["Pkg"] @@ -40,21 +40,21 @@ version = "0.4.1" [[CUDA]] deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"] -git-tree-sha1 = "39f6f584bec264ace76f924d1c8637c85617697e" +git-tree-sha1 = "6ccc73b2d8b671f7a65c92b5f08f81422ebb7547" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "2.4.0" +version = "2.4.1" [[ChainRules]] deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "31b28f5123afa5e5ca0c885e4051896032754578" +git-tree-sha1 = "56bbb956a573ac16b277008edb1762ef80076e78" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.45" +version = "0.7.50" [[ChainRulesCore]] -deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] -git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a" +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "d3d0a4e0d5bc03a6c97f4d249c8a471fc20a2f33" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.24" +version = "0.9.28" [[CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -93,15 +93,15 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "0.3.4+0" [[DataAPI]] -git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84" +git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.4.0" +version = "1.5.1" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "fb0aa371da91c1ff9dc7fbed6122d3e411420b9c" +git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.8" +version = "0.18.9" [[Dates]] deps = ["Printf"] @@ -134,9 +134,9 @@ version = "0.1.3" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "ff537e5a3cba92fb48f30fec46723510450f2c0e" +git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.10.2" +version = "0.11.1" [[FixedPointNumbers]] deps = ["Statistics"] @@ -146,15 +146,17 @@ version = "0.8.4" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "8de2519a83c6c1c2442c2f481dd9a8364855daf4" +git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.14" +version = "0.10.16" [[Functors]] deps = ["MacroTools"] -git-tree-sha1 = "f40adc6422f548176bb4351ebd29e4abf773040a" +git-tree-sha1 = "60dc8972fec14145524caf17edfee222ab531e37" +repo-rev = "dg/grad" +repo-url = "https://github.com/FluxML/Functors.jl.git" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.1.0" +version = "0.2.0" [[GPUArrays]] deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] @@ -191,9 +193,9 @@ version = "0.8.4" [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "d0d99629d6ae4a3e211ae83d8870907bd842c811" +git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "3.5.2" +version = "3.6.0" [[LibGit2]] deps = ["Printf"] @@ -227,23 +229,18 @@ version = "0.5.0" [[Missings]] deps = ["DataAPI"] -git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8" +git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.4" +version = "0.4.5" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" -[[MuladdMacro]] -git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.2" - [[NNlib]] -deps = ["ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "13fd29731c7f609cb82a3a544c5538584d22c153" +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "df42d0816edfc24f5b82a728f46381613c4dff79" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.11" +version = "0.7.14" [[NaNMath]] git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" @@ -256,10 +253,18 @@ git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.3+4" +[[Optimisers]] +deps = ["Functors", "Random", "Requires", "Statistics"] +git-tree-sha1 = "0a9f8b051708c9a41ead0810cf6aff9af2edb8bc" +repo-rev = "master" +repo-url = "https://github.com/FluxML/Optimisers.jl.git" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.1.0" + [[OrderedCollections]] -git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db" +git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.2" +version = "1.3.3" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -282,10 +287,9 @@ deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[Reexport]] -deps = ["Pkg"] -git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5" uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "0.2.0" +version = "1.0.0" [[Requires]] deps = ["UUIDs"] @@ -381,9 +385,9 @@ version = "1.2.11+18" [[Zygote]] deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "d88a7e71fc2eef9510187b1c7d4af7a5177633d0" +git-tree-sha1 = "52835a83f7c899cfcb95f796d584201812887ea8" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.0" +version = "0.6.3" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/Project.toml b/Project.toml index feaaf39d8d..8e5b56316f 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -31,7 +32,7 @@ Adapt = "2.0, 3.0" CUDA = "2.1" CodecZlib = "0.7" Colors = "0.12" -Functors = "0.1" +Functors = "0.1, 0.2" Juno = "0.8" MacroTools = "0.5" NNlib = "0.7.10" diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 53bb3f4345..21fe07e2cd 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down. opt = Descent(0.1) # Gradient descent with learning rate 0.1 for p in (W, b) - update!(opt, p, grads[p]) + opt(p, grads[p]) end ``` @@ -70,23 +70,24 @@ Flux's optimisers are built around a `struct` that holds all the optimiser param In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work this with a simple example. ```julia -mutable struct Momentum +struct Momentum eta rho - velocity end -Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict()) +Momentum(eta::Real, rho::Real) = Momentum(eta, rho) +Optimisers.init(opt::Momentum, x) = (zero(x),) ``` The `Momentum` type will act as our optimiser in this case. Notice that we have added all the parameters as fields, along with the velocity which we will use as our state dictionary. Each parameter in our models will get an entry in there. We can now define the rule applied when this optimiser is invoked. ```julia -function Flux.Optimise.apply!(o::Momentum, x, Δ) +function Flux.Optimise.apply(o::Momentum, x, Δ, st) η, ρ = o.eta, o.rho - v = get!(o.velocity, x, zero(x))::typeof(x) - @. v = ρ * v - η * Δ - @. Δ = -v + v, = st + v = @. ρ * v - η * Δ + Δ = @. -v + Δ, (v,) end ``` @@ -97,9 +98,9 @@ v = ρ * v - η * Δ w = w - v ``` -The `apply!` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and subsequently updates the state of the optimiser. +The `apply` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients and optimizer state. -Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully. +Flux internally calls on this function via the `update!` function. It shares the API with `apply` but ensures that multiple parameters are handled gracefully. ## Composing Optimisers @@ -121,11 +122,12 @@ ps = Params([w, w1]) loss(x) = Flux.Losses.mse(w * x, w1 * x) loss(rand(10)) # around 9 +st = Optimisers.init(opt, [w, w1]) for t = 1:10^5 - θ = Params([w, w1]) + θ = ps θ̄ = gradient(() -> loss(rand(10)), θ) - Flux.Optimise.update!(opt, θ, θ̄) + ps, st = Flux.Optimise.update!(opt, θ, θ̄, st) end loss(rand(10)) # around 0.9 diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index e2485a05d0..c0d56dc20e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,14 +1,14 @@ module Optimise using LinearAlgebra +using Optimisers +using Optimisers: apply -export train!, update!, - Descent, ADAM, Momentum, Nesterov, RMSProp, - ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, - InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, - ClipValue, ClipNorm +export train!, + Descent, ADAM, Momentum, Nesterov, RMSProp, + ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, + WeightDecay, stop, skip, ChainOptimiser -include("optimisers.jl") include("train.jl") end diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl deleted file mode 100644 index f60589273d..0000000000 --- a/src/optimise/optimisers.jl +++ /dev/null @@ -1,672 +0,0 @@ -using Flux -using MacroTools: @forward - -const ϵ = 1e-8 - -# TODO: should use weak refs - -""" - Descent(η = 0.1) - -Classic gradient descent optimiser with learning rate `η`. -For each parameter `p` and its gradient `δp`, this runs `p -= η*δp` - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. - -# Examples -```julia -opt = Descent() - -opt = Descent(0.3) - -ps = params(model) - -gs = gradient(ps) do - loss(x, y) -end - -Flux.Optimise.update!(opt, ps, gs) -``` -""" -mutable struct Descent - eta::Float64 -end - -Descent() = Descent(0.1) - -function apply!(o::Descent, x, Δ) - Δ .*= o.eta -end - -""" - Momentum(η = 0.01, ρ = 0.9) - -Gradient descent optimizer with learning rate `η` and momentum `ρ`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect dampening oscillations. - -# Examples -```julia -opt = Momentum() - -opt = Momentum(0.01, 0.99) -``` -""" -mutable struct Momentum - eta::Float64 - rho::Float64 - velocity::IdDict -end - -Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) - -function apply!(o::Momentum, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - @. v = ρ * v - η * Δ - @. Δ = -v -end - -""" - Nesterov(η = 0.001, ρ = 0.9) - -Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect dampening oscillations. - -# Examples -```julia -opt = Nesterov() - -opt = Nesterov(0.003, 0.95) -``` -""" -mutable struct Nesterov - eta::Float64 - rho::Float64 - velocity::IdDict -end - -Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) - -function apply!(o::Nesterov, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - d = @. ρ^2 * v - (1+ρ) * η * Δ - @. v = ρ*v - η*Δ - @. Δ = -d -end - -""" - RMSProp(η = 0.001, ρ = 0.9) - -Optimizer using the -[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -algorithm. Often a good choice for recurrent networks. Parameters other than learning rate -generally don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect dampening oscillations. - -# Examples -```julia -opt = RMSProp() - -opt = RMSProp(0.002, 0.95) -``` -""" -mutable struct RMSProp - eta::Float64 - rho::Float64 - acc::IdDict -end - -RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) - -function apply!(o::RMSProp, x, Δ) - η, ρ = o.eta, o.rho - acc = get!(() -> zero(x), o.acc, x)::typeof(x) - @. acc = ρ * acc + (1 - ρ) * Δ^2 - @. Δ *= η / (√acc + ϵ) -end - -""" - ADAM(η = 0.001, β::Tuple = (0.9, 0.999)) - -[ADAM](https://arxiv.org/abs/1412.6980) optimiser. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = ADAM() - -opt = ADAM(0.001, (0.9, 0.8)) -``` -""" -mutable struct ADAM - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) - -function apply!(o::ADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η - βp .= βp .* β - - return Δ -end - -""" - RADAM(η = 0.001, β::Tuple = (0.9, 0.999)) - -[Rectified ADAM](https://arxiv.org/abs/1908.03265) optimizer. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = RADAM() - -opt = RADAM(0.001, (0.9, 0.8)) -``` -""" -mutable struct RADAM - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict()) - -function apply!(o::RADAM, x, Δ) - η, β = o.eta, o.beta - ρ∞ = 2/(1-β[2])-1 - - mt, vt, βp, t = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]], Ref(1)) - end :: Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2]) - if ρ > 4 - r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ)) - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r - else - @. Δ = mt / (1 - βp[1]) * η - end - βp .= βp .* β - t[] += 1 - - return Δ -end - -""" - AdaMax(η = 0.001, β::Tuple = (0.9, 0.999)) - -[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of ADAM based on the ∞-norm. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AdaMax() - -opt = AdaMax(0.001, (0.9, 0.995)) -``` -""" -mutable struct AdaMax - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict()) - -function apply!(o::AdaMax, x, Δ) - η, β = o.eta, o.beta - - mt, ut, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. ut = max(β[2] * ut, abs(Δ)) - @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ) - βp .= βp .* β - - return Δ -end - -""" - OADAM(η = 0.0001, β::Tuple = (0.5, 0.9)) - -[OADAM](https://arxiv.org/abs/1711.00141) (Optimistic ADAM) -is a variant of ADAM adding an "optimistic" term suitable for adversarial training. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = OADAM() - -opt = OADAM(0.001, (0.9, 0.995)) -``` -""" -mutable struct OADAM - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -OADAM(η = 0.001, β = (0.5, 0.9)) = OADAM(η, β, IdDict()) - -function apply!(o::OADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, Δ_, βp = get!(o.state, x) do - (zero(x), zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = -Δ_ - @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) - @. Δ += 2Δ_ - βp .= βp .* β - - return Δ -end - -""" - ADAGrad(η = 0.1) - -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has -parameter specific learning rates based on how frequently it is updated. -Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. - -# Examples -```julia -opt = ADAGrad() - -opt = ADAGrad(0.001) -``` -""" -mutable struct ADAGrad - eta::Float64 - acc::IdDict -end - -ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) - -function apply!(o::ADAGrad, x, Δ) - η = o.eta - acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x) - @. acc += Δ^2 - @. Δ *= η / (√acc + ϵ) -end - -""" - ADADelta(ρ = 0.9) - -[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning -rate based on a window of past gradient updates. -Parameters don't need tuning. - -# Parameters -- Rho (`ρ`): Factor by which the gradient is decayed at each time step. - -# Examples -```julia -opt = ADADelta() - -opt = ADADelta(0.89) -``` -""" -mutable struct ADADelta - rho::Float64 - state::IdDict -end - -ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict()) - -function apply!(o::ADADelta, x, Δ) - ρ = o.rho - acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} - @. acc = ρ * acc + (1 - ρ) * Δ^2 - # DON'T remove epsilon from numerator - # or even out of the square roots - @. Δ *= √(Δacc + ϵ) / √(acc + ϵ) - @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2 - return Δ -end - -""" - AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999)) - -The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM -optimiser. Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AMSGrad() - -opt = AMSGrad(0.001, (0.89, 0.995)) -``` -""" -mutable struct AMSGrad - eta::Float64 - beta::Tuple{Float64, Float64} - state::IdDict -end - -AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict()) - -function apply!(o::AMSGrad, x, Δ) - η, β = o.eta, o.beta - - mt, vt, v̂t = get!(o.state, x) do - (fill!(similar(x), ϵ), fill!(similar(x), ϵ), fill!(similar(x), ϵ)) - end :: NTuple{3,typeof(x)} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2 - @. v̂t = max(v̂t, vt) - @. Δ = η * mt / (√v̂t + ϵ) -end - -""" - NADAM(η = 0.001, β::Tuple = (0.9, 0.999)) - -[NADAM](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of ADAM. -Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = NADAM() - -opt = NADAM(0.002, (0.89, 0.995)) -``` -""" -mutable struct NADAM - eta::Float64 - beta::Tuple{Float64, Float64} - state::IdDict -end - -NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) - -function apply!(o::NADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[o.beta[1], o.beta[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - β1p, β2p = βp - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η - βp .= βp .* β - - return Δ -end - -""" - ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0) - -[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its -weight decay regularization. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. -- `decay`: Decay applied to weights during optimisation. - -# Examples -```julia -opt = ADAMW() - -opt = ADAMW(0.001, (0.89, 0.995), 0.1) -``` -""" -ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) = - Optimiser(ADAM(η, β), WeightDecay(decay)) - -""" - AdaBelief(η = 0.001, β::Tuple = (0.9, 0.999)) - -The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known -ADAM optimiser. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AdaBelief() - -opt = AdaBelief(0.001, (0.9, 0.8)) -``` -""" -mutable struct AdaBelief - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -AdaBelief(η = 0.001, β = (0.9, 0.999)) = AdaBelief(η, β, IdDict()) - -function apply!(o::AdaBelief, x, Δ) - η, β = o.eta, o.beta - mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2 - @. Δ = η * mt / (√(st) + ϵ) - return Δ -end - - -# Compose optimizers - -""" - Optimiser(a, b, c...) - -Combine several optimisers into one; each optimiser produces a modified gradient -that will be fed into the next, and this is finally applied to the parameter as -usual. -""" -mutable struct Optimiser - os::Vector{Any} -end - -Optimiser(o...) = Optimiser(Any[o...]) - -@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! -@forward Optimiser.os Base.iterate - -Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) - -function apply!(o::Optimiser, x, Δ) - for opt in o.os - Δ = apply!(opt, x, Δ) - end - return Δ -end - -""" - InvDecay(γ = 0.001) - -Apply inverse time decay to an optimiser, so that the effective step size at -iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size. -The wrapped optimiser's step size is not modified. - -# Examples -```julia -Optimiser(InvDecay(..), Opt(..)) -``` -""" -mutable struct InvDecay - gamma::Float64 - state::IdDict -end - -InvDecay(γ = 0.001) = InvDecay(γ, IdDict()) - -function apply!(o::InvDecay, x, Δ) - γ = o.gamma - n = get!(o.state, x, 1) - Δ .*= 1 / (1 + γ * n) - o.state[x] = n + 1 - return Δ -end - -""" - ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) - -Discount the learning rate `η` by the factor `decay` every `decay_step` steps till -a minimum of `clip`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- `decay`: Factor by which the learning rate is discounted. -- `decay_step`: Schedule decay operations by setting the number of steps between - two decay operations. -- `clip`: Minimum value of learning rate. - -# Examples -To apply exponential decay to an optimiser: -```julia -Optimiser(ExpDecay(..), Opt(..)) - -opt = Optimiser(ExpDecay(), ADAM()) -``` -""" -mutable struct ExpDecay - eta::Float64 - decay::Float64 - step::Int64 - clip::Float64 - current::IdDict -end - -ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) - -function apply!(o::ExpDecay, x, Δ) - η, s, decay = o.eta, o.step, o.decay - n = o.current[x] = get(o.current, x, 0) + 1 - if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 - η = max(η * decay, o.clip) - o.eta = η - end - @. Δ *= η -end - -""" - WeightDecay(wd = 0) - -Decay weights by `wd`. - -# Parameters -- Weight decay (`wd`) -""" -mutable struct WeightDecay - wd::Real -end - -WeightDecay() = WeightDecay(0) - -function apply!(o::WeightDecay, x, Δ) - wd = o.wd - @. Δ += wd * x -end - -""" - ClipValue(thresh) - -Clip gradients when their absolute value exceeds `thresh`. -""" -mutable struct ClipValue{T} - thresh::T -end - -apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh) - -""" - ClipNorm(thresh) - -Clip gradients when their L2 norm exceeds `thresh`. -""" -mutable struct ClipNorm{T} - thresh::T -end - -function apply!(o::ClipNorm, x, Δ) - Δnrm = norm(Δ) - if Δnrm > o.thresh - rmul!(Δ, o.thresh / Δnrm) - end - return Δ -end diff --git a/src/optimise/train.jl b/src/optimise/train.jl index d487032ddf..9db9e30620 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -19,15 +19,20 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`). As a result, the parameters are mutated and the optimizer's internal state may change. """ -function update!(opt, x, x̄) - x .-= apply!(opt, x, x̄) +function update!(opt, x, x̄, st) + x̄, st = apply(opt, x, x̄, st) + x .-= x̄ + x̄, st end -function update!(opt, xs::Params, gs) - for x in xs +function update!(opt, xs::Params, gs, st) + st_ = [] + for (x,s) in zip(xs,st) gs[x] == nothing && continue - update!(opt, x, gs[x]) + _, s = update!(opt, x, gs[x], s) + append!(st_, s) end + xs, st_ end # Callback niceties @@ -74,6 +79,8 @@ function stop() throw(StopException()) end +Optimisers.init(o, ps::Params) = [Optimisers.init(o,p) for p in ps] + batchmemaybe(x) = tuple(x) batchmemaybe(x::Tuple) = x @@ -97,12 +104,13 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. function train!(loss, ps, data, opt; cb = () -> ()) ps = Params(ps) cb = runall(cb) + st = Optimisers.init(opt, ps) @progress for d in data try gs = gradient(ps) do loss(batchmemaybe(d)...) end - update!(opt, ps, gs) + _, st = update!(opt, ps, gs, st) cb() catch ex if ex isa StopException diff --git a/test/optimise.jl b/test/optimise.jl index 04cbf6f6c0..3041009a4e 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -1,6 +1,7 @@ using Flux.Optimise using Flux.Optimise: runall using Flux: Params, gradient +using Flux.Optimise: Optimisers using Test using Random @@ -9,41 +10,57 @@ using Random # so that w and w' are different Random.seed!(84) w = randn(10, 10) - @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), - NADAM(), RADAM(), Descent(0.1), ADAM(), OADAM(), AdaBelief(), + @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(), AMSGrad(), + NADAM(), RADAM(), Descent(), ADAM(), OADAM(), AdaBelief(), Nesterov(), RMSProp(), Momentum()] Random.seed!(42) w′ = randn(10, 10) b = Flux.Zeros() loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) + st = [Optimisers.init(opt, p) for p in [w′, b]] for t = 1: 10^5 θ = params([w′, b]) x = rand(10) θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) + _, st = Flux.Optimise.update!(opt, θ, θ̄, st) end @test loss(rand(10, 10)) < 0.01 end -end -@testset "Optimiser" begin Random.seed!(84) - w = randn(10, 10) - @testset for Opt in [InvDecay, WeightDecay, ExpDecay] - Random.seed!(42) - w′ = randn(10, 10) - loss(x) = Flux.Losses.mse(w*x, w′*x) - opt = Optimiser(Opt(), ADAM(0.001)) - for t = 1:10^5 - θ = Params([w′]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) + w′ = rand(3,3) + @testset for o in (Descent(0.1), Momentum(0.01, 0.9), Nesterov(0.001, 0.9), RMSProp(0.001, 0.9), + ADAM(0.001, (0.9, 0.99))) + w = rand(3,3) + st = Flux.Optimise.Optimisers.init(o,w) + loss(x, y) = mean((x .- y) .^ 2) + l = loss(w, w′) + for i = 1:10^4 + gs = gradient(x -> loss(x,w′), w) + w, st = o(w, gs..., st) end - @test loss(rand(10, 10)) < 0.01 + @test loss(w, w′) < 0.01 end end +# @testset "Optimiser" begin +# Random.seed!(84) +# w = randn(10, 10) +# @testset for Opt in [InvDecay, WeightDecay, ExpDecay] +# Random.seed!(42) +# w′ = randn(10, 10) +# loss(x) = Flux.Losses.mse(w*x, w′*x) +# opt = Optimiser(Opt(), ADAM(0.001)) +# for t = 1:10^5 +# θ = Params([w′]) +# x = rand(10) +# θ̄ = gradient(() -> loss(x), θ) +# Optimise.update!(opt, θ, θ̄) +# end +# @test loss(rand(10, 10)) < 0.01 +# end +# end + @testset "Training Loop" begin i = 0 l = 1 @@ -86,58 +103,59 @@ end Flux.train!(loss, Flux.params(r), (r,), Descent()) end -@testset "ExpDecay" begin - - @testset "Sanity Check" begin - o = ExpDecay(0.2, 0.5, 1, 1e-3) - p = [0.0] - steps = 1:8 - eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip) - eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] - @test eta_actual == eta_expected - end - - w = randn(10, 10) - o = ExpDecay(0.1, 0.1, 1000, 1e-4) - w1 = randn(10,10) - loss(x) = Flux.Losses.mse(w*x, w1*x) - flag = 1 - decay_steps = [] - for t = 1:10^5 - prev_eta = o.eta - θ = Params([w1]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - prev_grad = collect(θ̄[w1]) - delta = Optimise.apply!(o, w1, θ̄[w1]) - w1 .-= delta - new_eta = o.eta - if new_eta != prev_eta - push!(decay_steps, t) - end - array = fill(o.eta, size(prev_grad)) - if array .* prev_grad != delta - flag = 0 - end - end - @test flag == 1 - # Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). - ground_truth = [] - for i in 1:4 - push!(ground_truth, 1000*i) # Expected decay steps for this example. - end - @test decay_steps == ground_truth - @test o.eta == o.clip -end - -@testset "Clipping" begin - w = randn(10, 10) - loss(x) = sum(w * x) - θ = Params([w]) - x = 1000 * randn(10) - w̄ = gradient(() -> loss(x), θ)[w] - w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) - @test all(w̄_value .<= 1) - w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) - @test norm(w̄_norm) <= 1 -end +# @testset "ExpDecay" begin +# +# @testset "Sanity Check" begin +# o = ExpDecay(0.2, 0.5, 1, 1e-3) +# p = [0.0] +# steps = 1:8 +# eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip) +# eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] +# @test eta_actual == eta_expected +# end +# +# w = randn(10, 10) +# o = ExpDecay(0.1, 0.1, 1000, 1e-4) +# w1 = randn(10,10) +# loss(x) = Flux.Losses.mse(w*x, w1*x) +# flag = 1 +# decay_steps = [] +# for t = 1:10^5 +# prev_eta = o.eta +# θ = Params([w1]) +# x = rand(10) +# θ̄ = gradient(() -> loss(x), θ) +# prev_grad = collect(θ̄[w1]) +# delta = Optimise.apply!(o, w1, θ̄[w1]) +# w1 .-= delta +# new_eta = o.eta +# if new_eta != prev_eta +# push!(decay_steps, t) +# end +# array = fill(o.eta, size(prev_grad)) +# if array .* prev_grad != delta +# flag = 0 +# end +# end +# @test flag == 1 +# # Test to check if decay happens at decay steps. Eta reaches +# # clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). +# ground_truth = [] +# for i in 1:4 +# push!(ground_truth, 1000*i) # Expected decay steps for this example. +# end +# @test decay_steps == ground_truth +# @test o.eta == o.clip +# end +# +# @testset "Clipping" begin +# w = randn(10, 10) +# loss(x) = sum(w * x) +# θ = Params([w]) +# x = 1000 * randn(10) +# w̄ = gradient(() -> loss(x), θ)[w] +# w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) +# @test all(w̄_value .<= 1) +# w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) +# @test norm(w̄_norm) <= 1 +# end