From de65314729ed7de3e356418e0ede39a2969eac05 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 25 Jan 2021 01:35:17 +0530 Subject: [PATCH 01/25] add Optimisers --- src/optimise/Optimise.jl | 3 +- src/optimise/optimisers.jl | 154 +++++++++++++++++++------------------ src/optimise/train.jl | 102 +++++++++++++++++------- 3 files changed, 156 insertions(+), 103 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index e2485a05d0..f8cc27c28a 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -1,9 +1,10 @@ module Optimise using LinearAlgebra +using Optimisers +using Optimisers: Descent, ADAM, Momentum, Nesterov, RMSProp export train!, update!, - Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, ClipValue, ClipNorm diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl index f60589273d..0ceb45ae88 100644 --- a/src/optimise/optimisers.jl +++ b/src/optimise/optimisers.jl @@ -30,15 +30,16 @@ end Flux.Optimise.update!(opt, ps, gs) ``` """ -mutable struct Descent - eta::Float64 -end - -Descent() = Descent(0.1) - -function apply!(o::Descent, x, Δ) - Δ .*= o.eta -end +Descent +# mutable struct Descent +# eta::Float64 +# end +# +# Descent() = Descent(0.1) +# +# function apply!(o::Descent, x, Δ) +# Δ .*= o.eta +# end """ Momentum(η = 0.01, ρ = 0.9) @@ -58,20 +59,21 @@ opt = Momentum() opt = Momentum(0.01, 0.99) ``` """ -mutable struct Momentum - eta::Float64 - rho::Float64 - velocity::IdDict -end - -Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) - -function apply!(o::Momentum, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - @. v = ρ * v - η * Δ - @. Δ = -v -end +Momentum +# mutable struct Momentum +# eta::Float64 +# rho::Float64 +# velocity::IdDict +# end +# +# Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) +# +# function apply!(o::Momentum, x, Δ) +# η, ρ = o.eta, o.rho +# v = get!(() -> zero(x), o.velocity, x)::typeof(x) +# @. v = ρ * v - η * Δ +# @. Δ = -v +# end """ Nesterov(η = 0.001, ρ = 0.9) @@ -91,21 +93,22 @@ opt = Nesterov() opt = Nesterov(0.003, 0.95) ``` """ -mutable struct Nesterov - eta::Float64 - rho::Float64 - velocity::IdDict -end - -Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) - -function apply!(o::Nesterov, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - d = @. ρ^2 * v - (1+ρ) * η * Δ - @. v = ρ*v - η*Δ - @. Δ = -d -end +Nesterov +# mutable struct Nesterov +# eta::Float64 +# rho::Float64 +# velocity::IdDict +# end +# +# Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) +# +# function apply!(o::Nesterov, x, Δ) +# η, ρ = o.eta, o.rho +# v = get!(() -> zero(x), o.velocity, x)::typeof(x) +# d = @. ρ^2 * v - (1+ρ) * η * Δ +# @. v = ρ*v - η*Δ +# @. Δ = -d +# end """ RMSProp(η = 0.001, ρ = 0.9) @@ -128,20 +131,21 @@ opt = RMSProp() opt = RMSProp(0.002, 0.95) ``` """ -mutable struct RMSProp - eta::Float64 - rho::Float64 - acc::IdDict -end - -RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) - -function apply!(o::RMSProp, x, Δ) - η, ρ = o.eta, o.rho - acc = get!(() -> zero(x), o.acc, x)::typeof(x) - @. acc = ρ * acc + (1 - ρ) * Δ^2 - @. Δ *= η / (√acc + ϵ) -end +RMSProp +# mutable struct RMSProp +# eta::Float64 +# rho::Float64 +# acc::IdDict +# end +# +# RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) +# +# function apply!(o::RMSProp, x, Δ) +# η, ρ = o.eta, o.rho +# acc = get!(() -> zero(x), o.acc, x)::typeof(x) +# @. acc = ρ * acc + (1 - ρ) * Δ^2 +# @. Δ *= η / (√acc + ϵ) +# end """ ADAM(η = 0.001, β::Tuple = (0.9, 0.999)) @@ -161,28 +165,30 @@ opt = ADAM() opt = ADAM(0.001, (0.9, 0.8)) ``` """ -mutable struct ADAM - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) - -function apply!(o::ADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η - βp .= βp .* β - - return Δ -end +ADAM + +# mutable struct ADAM +# eta::Float64 +# beta::Tuple{Float64,Float64} +# state::IdDict +# end +# +# ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) +# +# function apply!(o::ADAM, x, Δ) +# η, β = o.eta, o.beta +# +# mt, vt, βp = get!(o.state, x) do +# (zero(x), zero(x), Float64[β[1], β[2]]) +# end :: Tuple{typeof(x),typeof(x),Vector{Float64}} +# +# @. mt = β[1] * mt + (1 - β[1]) * Δ +# @. vt = β[2] * vt + (1 - β[2]) * Δ^2 +# @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η +# βp .= βp .* β +# +# return Δ +# end """ RADAM(η = 0.001, β::Tuple = (0.9, 0.999)) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index d487032ddf..37501dcbfd 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,34 +1,34 @@ using Juno import Zygote: Params, gradient -""" - update!(x, x̄) - -Update the array `x` according to `x .-= x̄`. -""" -function update!(x::AbstractArray, x̄) - x .-= x̄ -end - -""" - update!(opt, p, g) - update!(opt, ps::Params, gs) - -Perform an update step of the parameters `ps` (or the single parameter `p`) -according to optimizer `opt` and the gradients `gs` (the gradient `g`). - -As a result, the parameters are mutated and the optimizer's internal state may change. -""" -function update!(opt, x, x̄) - x .-= apply!(opt, x, x̄) -end - -function update!(opt, xs::Params, gs) - for x in xs - gs[x] == nothing && continue - update!(opt, x, gs[x]) - end -end +# """ +# update!(x, x̄) +# +# Update the array `x` according to `x .-= x̄`. +# """ +# function update!(x::AbstractArray, x̄) +# x .-= x̄ +# end +# +# """ +# update!(opt, p, g) +# update!(opt, ps::Params, gs) +# +# Perform an update step of the parameters `ps` (or the single parameter `p`) +# according to optimizer `opt` and the gradients `gs` (the gradient `g`). +# +# As a result, the parameters are mutated and the optimizer's internal state may change. +# """ +# function update!(opt, x, x̄) +# x .-= apply!(opt, x, x̄) +# end +# +# function update!(opt, xs::Params, gs) +# for x in xs +# gs[x] == nothing && continue +# update!(opt, x, gs[x]) +# end +# end # Callback niceties call(f, xs...) = f(xs...) @@ -97,6 +97,7 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. function train!(loss, ps, data, opt; cb = () -> ()) ps = Params(ps) cb = runall(cb) + st = Optimisers.init(opt, ps) @progress for d in data try gs = gradient(ps) do @@ -116,6 +117,43 @@ function train!(loss, ps, data, opt; cb = () -> ()) end end +function train!(m, loss, data, opt; cb = (x...) -> () + prehook = (x...) -> (), + posthook = (x...) -> ()) + st = [Optimisers.init(opt, p) for p in Flux.params(m)] + prehook = runall(prehook) + posthook = runall(posthook) + cb = runall(cb) + + dlen = try + length(data) + catch e + @warn "Dataset length unkown" + 0 + end + + for d in data + try + ŷ, back = pullback(m) do m + loss(m, batchmemaybe(d)...) + end + prehook(ŷ) + m̂, = back(Zygote.sensitivity(ŷ)) + posthook(ŷ, m, m̂) + m, st = opt(m, m̂, st) + cb(ŷ, m, m̂) + catch ex + if ex isa StopException + break + elseif ex isa SkipException + continue + else + rethrow(ex) + end + end + end +end + """ @epochs N body @@ -137,3 +175,11 @@ macro epochs(n, ex) $(esc(ex)) end) end + +macro epochs(n, f, ex) + :(@progress for i = 1:$(esc(n)) + @info "Epoch $i" + $(esc(ex)) + $(esc(f($i))) + end) +end From 14abfa9aecdf0b3f6d375bc4f307ae83cab8d894 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 25 Jan 2021 01:35:42 +0530 Subject: [PATCH 02/25] add Optimisers to env --- Manifest.toml | 8 ++++++++ Project.toml | 1 + 2 files changed, 9 insertions(+) diff --git a/Manifest.toml b/Manifest.toml index b5496ecad7..e700c35477 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -256,6 +256,14 @@ git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.3+4" +[[Optimisers]] +deps = ["Functors", "Random", "Requires", "Statistics"] +git-tree-sha1 = "3130cebce66aed0943bdd8f3d61c641e44d7ea0e" +repo-rev = "master" +repo-url = "https://github.com/FluxML/Optimisers.jl.git" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.1.0" + [[OrderedCollections]] git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" diff --git a/Project.toml b/Project.toml index feaaf39d8d..e313e4e40f 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" From 07d647cd8f20453d9cd4fed913c3ea20b345cea2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 25 Jan 2021 01:43:33 +0530 Subject: [PATCH 03/25] go through some tests --- test/optimise.jl | 76 +++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/test/optimise.jl b/test/optimise.jl index 04cbf6f6c0..40db490f57 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -9,41 +9,56 @@ using Random # so that w and w' are different Random.seed!(84) w = randn(10, 10) - @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), - NADAM(), RADAM(), Descent(0.1), ADAM(), OADAM(), AdaBelief(), - Nesterov(), RMSProp(), Momentum()] - Random.seed!(42) - w′ = randn(10, 10) - b = Flux.Zeros() - loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) - for t = 1: 10^5 - θ = params([w′, b]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) - end - @test loss(rand(10, 10)) < 0.01 - end -end + # @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), + # NADAM(), RADAM(), Descent(), ADAM(), OADAM(), AdaBelief(), + # Nesterov(), RMSProp(), Momentum()] + # Random.seed!(42) + # w′ = randn(10, 10) + # b = Flux.Zeros() + # loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) + # for t = 1: 10^5 + # θ = params([w′, b]) + # x = rand(10) + # θ̄ = gradient(() -> loss(x), θ) + # Optimise.update!(opt, θ, θ̄) + # end + # @test loss(rand(10, 10)) < 0.01 + # end -@testset "Optimiser" begin Random.seed!(84) - w = randn(10, 10) - @testset for Opt in [InvDecay, WeightDecay, ExpDecay] - Random.seed!(42) - w′ = randn(10, 10) - loss(x) = Flux.Losses.mse(w*x, w′*x) - opt = Optimiser(Opt(), ADAM(0.001)) - for t = 1:10^5 - θ = Params([w′]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) + w′ = rand(3,3) + @testset for o in (Descent(0.1), Momentum(0.01, 0.9), Nesterov(0.001, 0.9), RMSProp(0.001, 0.9), + ADAM(0.001, (0.9, 0.99))) + w = rand(3,3) + st = Optimisers.init(o,w) + loss(x, y) = mean((x .- y) .^ 2) + l = loss(w, w′) + for i = 1:10^4 + gs = gradient(x -> loss(x,w′), w) + w, st = o(w, gs..., st) end - @test loss(rand(10, 10)) < 0.01 + @test loss(w, w′) < 0.01 end end +# @testset "Optimiser" begin +# Random.seed!(84) +# w = randn(10, 10) +# @testset for Opt in [InvDecay, WeightDecay, ExpDecay] +# Random.seed!(42) +# w′ = randn(10, 10) +# loss(x) = Flux.Losses.mse(w*x, w′*x) +# opt = Optimiser(Opt(), ADAM(0.001)) +# for t = 1:10^5 +# θ = Params([w′]) +# x = rand(10) +# θ̄ = gradient(() -> loss(x), θ) +# Optimise.update!(opt, θ, θ̄) +# end +# @test loss(rand(10, 10)) < 0.01 +# end +# end + @testset "Training Loop" begin i = 0 l = 1 @@ -121,7 +136,8 @@ end end end @test flag == 1 - # Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). + # Test to check if decay happens at decay steps. Eta reaches + # clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). ground_truth = [] for i in 1:4 push!(ground_truth, 1000*i) # Expected decay steps for this example. From 0328aab7e4175bb10a246d5454d9830fcb449c05 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 26 Jan 2021 13:30:55 +0530 Subject: [PATCH 04/25] general runall --- src/optimise/train.jl | 68 ++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 37501dcbfd..d93e715897 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,5 +1,5 @@ using Juno -import Zygote: Params, gradient +import Zygote: Params, pullback, gradient, sensitivity # """ # update!(x, x̄) @@ -33,7 +33,7 @@ import Zygote: Params, gradient # Callback niceties call(f, xs...) = f(xs...) runall(f) = f -runall(fs::AbstractVector) = () -> foreach(call, fs) +runall(fs::AbstractVector) = (x...) -> foreach(f -> call(f, x...), fs) struct SkipException <: Exception end @@ -94,30 +94,30 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ -function train!(loss, ps, data, opt; cb = () -> ()) - ps = Params(ps) - cb = runall(cb) - st = Optimisers.init(opt, ps) - @progress for d in data - try - gs = gradient(ps) do - loss(batchmemaybe(d)...) - end - update!(opt, ps, gs) - cb() - catch ex - if ex isa StopException - break - elseif ex isa SkipException - continue - else - rethrow(ex) - end - end - end -end +# function train!(loss, ps, data, opt; cb = () -> ()) +# ps = Params(ps) +# cb = runall(cb) +# st = Optimisers.init(opt, ps) +# @progress for d in data +# try +# gs = gradient(ps) do +# loss(batchmemaybe(d)...) +# end +# update!(opt, ps, gs) +# cb() +# catch ex +# if ex isa StopException +# break +# elseif ex isa SkipException +# continue +# else +# rethrow(ex) +# end +# end +# end +# end -function train!(m, loss, data, opt; cb = (x...) -> () +function train(m, loss, data, opt; cb = (x...) -> (), prehook = (x...) -> (), posthook = (x...) -> ()) st = [Optimisers.init(opt, p) for p in Flux.params(m)] @@ -138,7 +138,7 @@ function train!(m, loss, data, opt; cb = (x...) -> () loss(m, batchmemaybe(d)...) end prehook(ŷ) - m̂, = back(Zygote.sensitivity(ŷ)) + m̂, = back(sensitivity(ŷ)) posthook(ŷ, m, m̂) m, st = opt(m, m̂, st) cb(ŷ, m, m̂) @@ -152,6 +152,8 @@ function train!(m, loss, data, opt; cb = (x...) -> () end end end + + m end """ @@ -176,10 +178,10 @@ macro epochs(n, ex) end) end -macro epochs(n, f, ex) - :(@progress for i = 1:$(esc(n)) - @info "Epoch $i" - $(esc(ex)) - $(esc(f($i))) - end) -end +# macro epochs(n, f, ex) +# :(@progress for i = 1:$(esc(n)) +# @info "Epoch $i" +# $(esc(ex)) +# $(esc(f($i))) +# end) +# end From 91a1de0cb3dc78a36b6b743ec319c9596b47ecd7 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 26 Jan 2021 13:31:22 +0530 Subject: [PATCH 05/25] pkg up --- Manifest.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Manifest.toml b/Manifest.toml index e700c35477..4548c66d3f 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -152,7 +152,9 @@ version = "0.10.14" [[Functors]] deps = ["MacroTools"] -git-tree-sha1 = "f40adc6422f548176bb4351ebd29e4abf773040a" +git-tree-sha1 = "5b082c8f4507e15d435f7727efc9d7006f954140" +repo-rev = "dg/grad" +repo-url = "https://github.com/FluxML/Functors.jl.git" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" version = "0.1.0" From fabd6eb3b333c5cfb8550c114dfabb21785c90c2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 3 Feb 2021 18:23:01 +0530 Subject: [PATCH 06/25] rm optimisers --- src/optimise/optimisers.jl | 678 ------------------------------------- 1 file changed, 678 deletions(-) delete mode 100644 src/optimise/optimisers.jl diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl deleted file mode 100644 index 0ceb45ae88..0000000000 --- a/src/optimise/optimisers.jl +++ /dev/null @@ -1,678 +0,0 @@ -using Flux -using MacroTools: @forward - -const ϵ = 1e-8 - -# TODO: should use weak refs - -""" - Descent(η = 0.1) - -Classic gradient descent optimiser with learning rate `η`. -For each parameter `p` and its gradient `δp`, this runs `p -= η*δp` - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. - -# Examples -```julia -opt = Descent() - -opt = Descent(0.3) - -ps = params(model) - -gs = gradient(ps) do - loss(x, y) -end - -Flux.Optimise.update!(opt, ps, gs) -``` -""" -Descent -# mutable struct Descent -# eta::Float64 -# end -# -# Descent() = Descent(0.1) -# -# function apply!(o::Descent, x, Δ) -# Δ .*= o.eta -# end - -""" - Momentum(η = 0.01, ρ = 0.9) - -Gradient descent optimizer with learning rate `η` and momentum `ρ`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect dampening oscillations. - -# Examples -```julia -opt = Momentum() - -opt = Momentum(0.01, 0.99) -``` -""" -Momentum -# mutable struct Momentum -# eta::Float64 -# rho::Float64 -# velocity::IdDict -# end -# -# Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) -# -# function apply!(o::Momentum, x, Δ) -# η, ρ = o.eta, o.rho -# v = get!(() -> zero(x), o.velocity, x)::typeof(x) -# @. v = ρ * v - η * Δ -# @. Δ = -v -# end - -""" - Nesterov(η = 0.001, ρ = 0.9) - -Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect dampening oscillations. - -# Examples -```julia -opt = Nesterov() - -opt = Nesterov(0.003, 0.95) -``` -""" -Nesterov -# mutable struct Nesterov -# eta::Float64 -# rho::Float64 -# velocity::IdDict -# end -# -# Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) -# -# function apply!(o::Nesterov, x, Δ) -# η, ρ = o.eta, o.rho -# v = get!(() -> zero(x), o.velocity, x)::typeof(x) -# d = @. ρ^2 * v - (1+ρ) * η * Δ -# @. v = ρ*v - η*Δ -# @. Δ = -d -# end - -""" - RMSProp(η = 0.001, ρ = 0.9) - -Optimizer using the -[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -algorithm. Often a good choice for recurrent networks. Parameters other than learning rate -generally don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect dampening oscillations. - -# Examples -```julia -opt = RMSProp() - -opt = RMSProp(0.002, 0.95) -``` -""" -RMSProp -# mutable struct RMSProp -# eta::Float64 -# rho::Float64 -# acc::IdDict -# end -# -# RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict()) -# -# function apply!(o::RMSProp, x, Δ) -# η, ρ = o.eta, o.rho -# acc = get!(() -> zero(x), o.acc, x)::typeof(x) -# @. acc = ρ * acc + (1 - ρ) * Δ^2 -# @. Δ *= η / (√acc + ϵ) -# end - -""" - ADAM(η = 0.001, β::Tuple = (0.9, 0.999)) - -[ADAM](https://arxiv.org/abs/1412.6980) optimiser. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = ADAM() - -opt = ADAM(0.001, (0.9, 0.8)) -``` -""" -ADAM - -# mutable struct ADAM -# eta::Float64 -# beta::Tuple{Float64,Float64} -# state::IdDict -# end -# -# ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict()) -# -# function apply!(o::ADAM, x, Δ) -# η, β = o.eta, o.beta -# -# mt, vt, βp = get!(o.state, x) do -# (zero(x), zero(x), Float64[β[1], β[2]]) -# end :: Tuple{typeof(x),typeof(x),Vector{Float64}} -# -# @. mt = β[1] * mt + (1 - β[1]) * Δ -# @. vt = β[2] * vt + (1 - β[2]) * Δ^2 -# @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η -# βp .= βp .* β -# -# return Δ -# end - -""" - RADAM(η = 0.001, β::Tuple = (0.9, 0.999)) - -[Rectified ADAM](https://arxiv.org/abs/1908.03265) optimizer. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = RADAM() - -opt = RADAM(0.001, (0.9, 0.8)) -``` -""" -mutable struct RADAM - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict()) - -function apply!(o::RADAM, x, Δ) - η, β = o.eta, o.beta - ρ∞ = 2/(1-β[2])-1 - - mt, vt, βp, t = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]], Ref(1)) - end :: Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2]) - if ρ > 4 - r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ)) - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r - else - @. Δ = mt / (1 - βp[1]) * η - end - βp .= βp .* β - t[] += 1 - - return Δ -end - -""" - AdaMax(η = 0.001, β::Tuple = (0.9, 0.999)) - -[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of ADAM based on the ∞-norm. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AdaMax() - -opt = AdaMax(0.001, (0.9, 0.995)) -``` -""" -mutable struct AdaMax - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict()) - -function apply!(o::AdaMax, x, Δ) - η, β = o.eta, o.beta - - mt, ut, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. ut = max(β[2] * ut, abs(Δ)) - @. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ) - βp .= βp .* β - - return Δ -end - -""" - OADAM(η = 0.0001, β::Tuple = (0.5, 0.9)) - -[OADAM](https://arxiv.org/abs/1711.00141) (Optimistic ADAM) -is a variant of ADAM adding an "optimistic" term suitable for adversarial training. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = OADAM() - -opt = OADAM(0.001, (0.9, 0.995)) -``` -""" -mutable struct OADAM - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -OADAM(η = 0.001, β = (0.5, 0.9)) = OADAM(η, β, IdDict()) - -function apply!(o::OADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, Δ_, βp = get!(o.state, x) do - (zero(x), zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = -Δ_ - @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) - @. Δ += 2Δ_ - βp .= βp .* β - - return Δ -end - -""" - ADAGrad(η = 0.1) - -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimizer. It has -parameter specific learning rates based on how frequently it is updated. -Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. - -# Examples -```julia -opt = ADAGrad() - -opt = ADAGrad(0.001) -``` -""" -mutable struct ADAGrad - eta::Float64 - acc::IdDict -end - -ADAGrad(η = 0.1) = ADAGrad(η, IdDict()) - -function apply!(o::ADAGrad, x, Δ) - η = o.eta - acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x) - @. acc += Δ^2 - @. Δ *= η / (√acc + ϵ) -end - -""" - ADADelta(ρ = 0.9) - -[ADADelta](https://arxiv.org/abs/1212.5701) is a version of ADAGrad adapting its learning -rate based on a window of past gradient updates. -Parameters don't need tuning. - -# Parameters -- Rho (`ρ`): Factor by which the gradient is decayed at each time step. - -# Examples -```julia -opt = ADADelta() - -opt = ADADelta(0.89) -``` -""" -mutable struct ADADelta - rho::Float64 - state::IdDict -end - -ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict()) - -function apply!(o::ADADelta, x, Δ) - ρ = o.rho - acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} - @. acc = ρ * acc + (1 - ρ) * Δ^2 - # DON'T remove epsilon from numerator - # or even out of the square roots - @. Δ *= √(Δacc + ϵ) / √(acc + ϵ) - @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2 - return Δ -end - -""" - AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999)) - -The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the ADAM -optimiser. Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AMSGrad() - -opt = AMSGrad(0.001, (0.89, 0.995)) -``` -""" -mutable struct AMSGrad - eta::Float64 - beta::Tuple{Float64, Float64} - state::IdDict -end - -AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict()) - -function apply!(o::AMSGrad, x, Δ) - η, β = o.eta, o.beta - - mt, vt, v̂t = get!(o.state, x) do - (fill!(similar(x), ϵ), fill!(similar(x), ϵ), fill!(similar(x), ϵ)) - end :: NTuple{3,typeof(x)} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2 - @. v̂t = max(v̂t, vt) - @. Δ = η * mt / (√v̂t + ϵ) -end - -""" - NADAM(η = 0.001, β::Tuple = (0.9, 0.999)) - -[NADAM](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of ADAM. -Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = NADAM() - -opt = NADAM(0.002, (0.89, 0.995)) -``` -""" -mutable struct NADAM - eta::Float64 - beta::Tuple{Float64, Float64} - state::IdDict -end - -NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict()) - -function apply!(o::NADAM, x, Δ) - η, β = o.eta, o.beta - - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[o.beta[1], o.beta[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - β1p, β2p = βp - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ^2 - @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η - βp .= βp .* β - - return Δ -end - -""" - ADAMW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0) - -[ADAMW](https://arxiv.org/abs/1711.05101) is a variant of ADAM fixing (as in repairing) its -weight decay regularization. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. -- `decay`: Decay applied to weights during optimisation. - -# Examples -```julia -opt = ADAMW() - -opt = ADAMW(0.001, (0.89, 0.995), 0.1) -``` -""" -ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) = - Optimiser(ADAM(η, β), WeightDecay(decay)) - -""" - AdaBelief(η = 0.001, β::Tuple = (0.9, 0.999)) - -The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known -ADAM optimiser. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AdaBelief() - -opt = AdaBelief(0.001, (0.9, 0.8)) -``` -""" -mutable struct AdaBelief - eta::Float64 - beta::Tuple{Float64,Float64} - state::IdDict -end - -AdaBelief(η = 0.001, β = (0.9, 0.999)) = AdaBelief(η, β, IdDict()) - -function apply!(o::AdaBelief, x, Δ) - η, β = o.eta, o.beta - mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2 - @. Δ = η * mt / (√(st) + ϵ) - return Δ -end - - -# Compose optimizers - -""" - Optimiser(a, b, c...) - -Combine several optimisers into one; each optimiser produces a modified gradient -that will be fed into the next, and this is finally applied to the parameter as -usual. -""" -mutable struct Optimiser - os::Vector{Any} -end - -Optimiser(o...) = Optimiser(Any[o...]) - -@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! -@forward Optimiser.os Base.iterate - -Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) - -function apply!(o::Optimiser, x, Δ) - for opt in o.os - Δ = apply!(opt, x, Δ) - end - return Δ -end - -""" - InvDecay(γ = 0.001) - -Apply inverse time decay to an optimiser, so that the effective step size at -iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size. -The wrapped optimiser's step size is not modified. - -# Examples -```julia -Optimiser(InvDecay(..), Opt(..)) -``` -""" -mutable struct InvDecay - gamma::Float64 - state::IdDict -end - -InvDecay(γ = 0.001) = InvDecay(γ, IdDict()) - -function apply!(o::InvDecay, x, Δ) - γ = o.gamma - n = get!(o.state, x, 1) - Δ .*= 1 / (1 + γ * n) - o.state[x] = n + 1 - return Δ -end - -""" - ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) - -Discount the learning rate `η` by the factor `decay` every `decay_step` steps till -a minimum of `clip`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- `decay`: Factor by which the learning rate is discounted. -- `decay_step`: Schedule decay operations by setting the number of steps between - two decay operations. -- `clip`: Minimum value of learning rate. - -# Examples -To apply exponential decay to an optimiser: -```julia -Optimiser(ExpDecay(..), Opt(..)) - -opt = Optimiser(ExpDecay(), ADAM()) -``` -""" -mutable struct ExpDecay - eta::Float64 - decay::Float64 - step::Int64 - clip::Float64 - current::IdDict -end - -ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) = ExpDecay(opt, decay, decay_step, clip, IdDict()) - -function apply!(o::ExpDecay, x, Δ) - η, s, decay = o.eta, o.step, o.decay - n = o.current[x] = get(o.current, x, 0) + 1 - if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1 - η = max(η * decay, o.clip) - o.eta = η - end - @. Δ *= η -end - -""" - WeightDecay(wd = 0) - -Decay weights by `wd`. - -# Parameters -- Weight decay (`wd`) -""" -mutable struct WeightDecay - wd::Real -end - -WeightDecay() = WeightDecay(0) - -function apply!(o::WeightDecay, x, Δ) - wd = o.wd - @. Δ += wd * x -end - -""" - ClipValue(thresh) - -Clip gradients when their absolute value exceeds `thresh`. -""" -mutable struct ClipValue{T} - thresh::T -end - -apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh) - -""" - ClipNorm(thresh) - -Clip gradients when their L2 norm exceeds `thresh`. -""" -mutable struct ClipNorm{T} - thresh::T -end - -function apply!(o::ClipNorm, x, Δ) - Δnrm = norm(Δ) - if Δnrm > o.thresh - rmul!(Δ, o.thresh / Δnrm) - end - return Δ -end From 693143860d11a4a9e77c004aa706f14bc219bad7 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 3 Feb 2021 18:25:26 +0530 Subject: [PATCH 07/25] fix exports --- src/optimise/Optimise.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index f8cc27c28a..16404a6507 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -2,12 +2,11 @@ module Optimise using LinearAlgebra using Optimisers -using Optimisers: Descent, ADAM, Momentum, Nesterov, RMSProp -export train!, update!, - ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, - InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, - ClipValue, ClipNorm +export train!, + ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, + InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, + ClipValue, ClipNorm include("optimisers.jl") include("train.jl") From 2ae9b72a10c1b9ca7bb3875187e8e31eaef9cda3 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 3 Feb 2021 18:32:09 +0530 Subject: [PATCH 08/25] use immutable apply in update --- src/optimise/train.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index d93e715897..eb0662036a 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -19,14 +19,15 @@ import Zygote: Params, pullback, gradient, sensitivity # # As a result, the parameters are mutated and the optimizer's internal state may change. # """ -# function update!(opt, x, x̄) -# x .-= apply!(opt, x, x̄) +# function update!(opt, x, x̄, st) +# Δ, st = x .- apply(opt, x, x̄, st) +# update!(x, Δ) # end # -# function update!(opt, xs::Params, gs) +# function update!(opt, xs::Params, gs, st) # for x in xs # gs[x] == nothing && continue -# update!(opt, x, gs[x]) +# _, st = update!(opt, x, gs[x], st) # end # end From 2425ffb44a764fd78fe5152f1cd1a9e97bf09079 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 3 Feb 2021 18:36:49 +0530 Subject: [PATCH 09/25] typo --- src/optimise/train.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index eb0662036a..18146a590e 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -22,6 +22,7 @@ import Zygote: Params, pullback, gradient, sensitivity # function update!(opt, x, x̄, st) # Δ, st = x .- apply(opt, x, x̄, st) # update!(x, Δ) +# Δ, st # end # # function update!(opt, xs::Params, gs, st) From 50a0579698e0bcb3e2f160fcdb3fde50b4fe7e03 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 4 Feb 2021 16:11:40 +0530 Subject: [PATCH 10/25] add doc references --- docs/src/training/optimisers.md | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 53bb3f4345..af23cb43ed 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down. opt = Descent(0.1) # Gradient descent with learning rate 0.1 for p in (W, b) - update!(opt, p, grads[p]) + opt(p, grads[p]) end ``` @@ -70,23 +70,25 @@ Flux's optimisers are built around a `struct` that holds all the optimiser param In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work this with a simple example. ```julia -mutable struct Momentum +struct Momentum eta rho velocity end -Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict()) +Momentum(eta::Real, rho::Real) = Momentum(eta, rho) +Optimisers.init(opt::Momentum, x) = (zero(x),) ``` The `Momentum` type will act as our optimiser in this case. Notice that we have added all the parameters as fields, along with the velocity which we will use as our state dictionary. Each parameter in our models will get an entry in there. We can now define the rule applied when this optimiser is invoked. ```julia -function Flux.Optimise.apply!(o::Momentum, x, Δ) +function Flux.Optimise.apply(o::Momentum, x, Δ, st) η, ρ = o.eta, o.rho - v = get!(o.velocity, x, zero(x))::typeof(x) - @. v = ρ * v - η * Δ - @. Δ = -v + v, = st + v = @. ρ * v - η * Δ + Δ = @. -v + Δ, (v,) end ``` @@ -97,9 +99,9 @@ v = ρ * v - η * Δ w = w - v ``` -The `apply!` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and subsequently updates the state of the optimiser. +The `apply` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and returns the new state of the optimizer. -Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully. +Flux internally calls on this function via the `update!` function. It shares the API with `apply` but ensures that multiple parameters are handled gracefully. ## Composing Optimisers From 312643fb77226bc24e0aba04b277780206f7c70e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 4 Feb 2021 16:16:28 +0530 Subject: [PATCH 11/25] fixes --- docs/src/training/optimisers.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index af23cb43ed..05947da26a 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -73,7 +73,6 @@ In this manner Flux also allows one to create custom optimisers to be used seaml struct Momentum eta rho - velocity end Momentum(eta::Real, rho::Real) = Momentum(eta, rho) @@ -123,11 +122,12 @@ ps = Params([w, w1]) loss(x) = Flux.Losses.mse(w * x, w1 * x) loss(rand(10)) # around 9 +st = Optimisers.init(opt, [w, w1]) for t = 1:10^5 - θ = Params([w, w1]) + θ = ps θ̄ = gradient(() -> loss(rand(10)), θ) - Flux.Optimise.update!(opt, θ, θ̄) + ps, st = Flux.Optimise.update!(opt, θ, θ̄, st) end loss(rand(10)) # around 0.9 From d4f27a719a7b8a9de148950ca3a1d4a582da39f1 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 8 Feb 2021 21:29:40 +0530 Subject: [PATCH 12/25] rm train changes --- src/optimise/train.jl | 133 +++++++++++------------------------------- 1 file changed, 34 insertions(+), 99 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 18146a590e..7b1708eac6 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,50 +1,43 @@ using Juno -import Zygote: Params, pullback, gradient, sensitivity +import Zygote: Params, gradient -# """ -# update!(x, x̄) -# -# Update the array `x` according to `x .-= x̄`. -# """ -# function update!(x::AbstractArray, x̄) -# x .-= x̄ -# end -# -# """ -# update!(opt, p, g) -# update!(opt, ps::Params, gs) -# -# Perform an update step of the parameters `ps` (or the single parameter `p`) -# according to optimizer `opt` and the gradients `gs` (the gradient `g`). -# -# As a result, the parameters are mutated and the optimizer's internal state may change. -# """ -# function update!(opt, x, x̄, st) -# Δ, st = x .- apply(opt, x, x̄, st) -# update!(x, Δ) -# Δ, st -# end -# -# function update!(opt, xs::Params, gs, st) -# for x in xs -# gs[x] == nothing && continue -# _, st = update!(opt, x, gs[x], st) -# end -# end +""" + update!(x, x̄) +Update the array `x` according to `x .-= x̄`. +""" +function update!(x::AbstractArray, x̄) + x .-= x̄ +end + +""" + update!(opt, p, g) + update!(opt, ps::Params, gs) +Perform an update step of the parameters `ps` (or the single parameter `p`) +according to optimizer `opt` and the gradients `gs` (the gradient `g`). +As a result, the parameters are mutated and the optimizer's internal state may change. +""" +function update!(opt, x, x̄) + x .-= apply!(opt, x, x̄) +end + +function update!(opt, xs::Params, gs) + for x in xs + gs[x] == nothing && continue + update!(opt, x, gs[x]) + end +end # Callback niceties call(f, xs...) = f(xs...) runall(f) = f -runall(fs::AbstractVector) = (x...) -> foreach(f -> call(f, x...), fs) +runall(fs::AbstractVector) = () -> foreach(call, fs) struct SkipException <: Exception end """ skip() - Call `Flux.skip()` in a callback to indicate when a callback condition is met. This will trigger the train loop to skip the current data point and not update with the calculated gradient. - # Examples ```julia cb = function () @@ -61,10 +54,8 @@ struct StopException <: Exception end """ stop() - Call `Flux.stop()` in a callback to indicate when a callback condition is met. This will trigger the train loop to stop and exit. - # Examples ```julia cb = function () @@ -81,69 +72,25 @@ batchmemaybe(x::Tuple) = x """ train!(loss, params, data, opt; cb) - For each datapoint `d` in `data`, compute the gradient of `loss` with respect to `params` through backpropagation and call the optimizer `opt`. - If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`. - A callback is given with the keyword argument `cb`. For example, this will print "training" every 10 seconds (using [`Flux.throttle`](@ref)): - train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10)) - The callback can call [`Flux.stop`](@ref) to interrupt the training loop. - Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ -# function train!(loss, ps, data, opt; cb = () -> ()) -# ps = Params(ps) -# cb = runall(cb) -# st = Optimisers.init(opt, ps) -# @progress for d in data -# try -# gs = gradient(ps) do -# loss(batchmemaybe(d)...) -# end -# update!(opt, ps, gs) -# cb() -# catch ex -# if ex isa StopException -# break -# elseif ex isa SkipException -# continue -# else -# rethrow(ex) -# end -# end -# end -# end - -function train(m, loss, data, opt; cb = (x...) -> (), - prehook = (x...) -> (), - posthook = (x...) -> ()) - st = [Optimisers.init(opt, p) for p in Flux.params(m)] - prehook = runall(prehook) - posthook = runall(posthook) +function train!(loss, ps, data, opt; cb = () -> ()) + ps = Params(ps) cb = runall(cb) - - dlen = try - length(data) - catch e - @warn "Dataset length unkown" - 0 - end - - for d in data + @progress for d in data try - ŷ, back = pullback(m) do m - loss(m, batchmemaybe(d)...) + gs = gradient(ps) do + loss(batchmemaybe(d)...) end - prehook(ŷ) - m̂, = back(sensitivity(ŷ)) - posthook(ŷ, m, m̂) - m, st = opt(m, m̂, st) - cb(ŷ, m, m̂) + update!(opt, ps, gs) + cb() catch ex if ex isa StopException break @@ -154,16 +101,12 @@ function train(m, loss, data, opt; cb = (x...) -> (), end end end - - m end """ @epochs N body - Run `body` `N` times. Mainly useful for quickly doing multiple epochs of training in a REPL. - # Examples ```jldoctest julia> Flux.@epochs 2 println("hello") @@ -179,11 +122,3 @@ macro epochs(n, ex) $(esc(ex)) end) end - -# macro epochs(n, f, ex) -# :(@progress for i = 1:$(esc(n)) -# @info "Epoch $i" -# $(esc(ex)) -# $(esc(f($i))) -# end) -# end From 4c5d6697ffa9ef3c7e396f0290b90e119d22542e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 8 Feb 2021 21:32:24 +0530 Subject: [PATCH 13/25] allow explicit state --- src/optimise/train.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 7b1708eac6..8be2567724 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -16,14 +16,16 @@ Perform an update step of the parameters `ps` (or the single parameter `p`) according to optimizer `opt` and the gradients `gs` (the gradient `g`). As a result, the parameters are mutated and the optimizer's internal state may change. """ -function update!(opt, x, x̄) - x .-= apply!(opt, x, x̄) +function update!(opt, x, x̄, st) + x̄, st = apply(opt, x, x̄, st) + x .-= x̄ + x̄, st end -function update!(opt, xs::Params, gs) +function update!(opt, xs::Params, gs, st) for x in xs gs[x] == nothing && continue - update!(opt, x, gs[x]) + _, st = update!(opt, x, gs[x], st) end end @@ -84,12 +86,13 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. function train!(loss, ps, data, opt; cb = () -> ()) ps = Params(ps) cb = runall(cb) + st = [Optimisers.init(opt, p) for p in ps] @progress for d in data try gs = gradient(ps) do loss(batchmemaybe(d)...) end - update!(opt, ps, gs) + _, st = update!(opt, ps, gs, st) cb() catch ex if ex isa StopException From 319d7c23c2a5d11ab1679ce0a2e08e9f8243d93b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 8 Feb 2021 21:35:21 +0530 Subject: [PATCH 14/25] git fixes --- src/optimise/train.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 8be2567724..cf7aa7300c 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -3,6 +3,7 @@ import Zygote: Params, gradient """ update!(x, x̄) + Update the array `x` according to `x .-= x̄`. """ function update!(x::AbstractArray, x̄) @@ -12,8 +13,10 @@ end """ update!(opt, p, g) update!(opt, ps::Params, gs) + Perform an update step of the parameters `ps` (or the single parameter `p`) according to optimizer `opt` and the gradients `gs` (the gradient `g`). + As a result, the parameters are mutated and the optimizer's internal state may change. """ function update!(opt, x, x̄, st) @@ -38,8 +41,10 @@ struct SkipException <: Exception end """ skip() + Call `Flux.skip()` in a callback to indicate when a callback condition is met. This will trigger the train loop to skip the current data point and not update with the calculated gradient. + # Examples ```julia cb = function () @@ -56,8 +61,10 @@ struct StopException <: Exception end """ stop() + Call `Flux.stop()` in a callback to indicate when a callback condition is met. This will trigger the train loop to stop and exit. + # Examples ```julia cb = function () @@ -74,13 +81,19 @@ batchmemaybe(x::Tuple) = x """ train!(loss, params, data, opt; cb) + For each datapoint `d` in `data`, compute the gradient of `loss` with respect to `params` through backpropagation and call the optimizer `opt`. + If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`. + A callback is given with the keyword argument `cb`. For example, this will print "training" every 10 seconds (using [`Flux.throttle`](@ref)): + train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10)) + The callback can call [`Flux.stop`](@ref) to interrupt the training loop. + Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. """ function train!(loss, ps, data, opt; cb = () -> ()) @@ -108,8 +121,10 @@ end """ @epochs N body + Run `body` `N` times. Mainly useful for quickly doing multiple epochs of training in a REPL. + # Examples ```jldoctest julia> Flux.@epochs 2 println("hello") From 28e6a7eb11e513c0e492e725adec0013e9cf6ed7 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 9 Feb 2021 20:30:21 +0530 Subject: [PATCH 15/25] pkg up + compat --- Manifest.toml | 72 +++++++++++++++++++++++---------------------------- Project.toml | 2 +- 2 files changed, 34 insertions(+), 40 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 4548c66d3f..6bf05007af 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -14,9 +14,9 @@ version = "0.3.3" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "27edd95a09fd428113ca019c092e8aeca2eb1f2d" +git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.0.0" +version = "3.2.0" [[Artifacts]] deps = ["Pkg"] @@ -40,21 +40,21 @@ version = "0.4.1" [[CUDA]] deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"] -git-tree-sha1 = "39f6f584bec264ace76f924d1c8637c85617697e" +git-tree-sha1 = "6ccc73b2d8b671f7a65c92b5f08f81422ebb7547" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "2.4.0" +version = "2.4.1" [[ChainRules]] deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "31b28f5123afa5e5ca0c885e4051896032754578" +git-tree-sha1 = "56bbb956a573ac16b277008edb1762ef80076e78" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.45" +version = "0.7.50" [[ChainRulesCore]] -deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] -git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a" +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "d3d0a4e0d5bc03a6c97f4d249c8a471fc20a2f33" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.24" +version = "0.9.28" [[CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -93,15 +93,15 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "0.3.4+0" [[DataAPI]] -git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84" +git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.4.0" +version = "1.5.1" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "fb0aa371da91c1ff9dc7fbed6122d3e411420b9c" +git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.8" +version = "0.18.9" [[Dates]] deps = ["Printf"] @@ -134,9 +134,9 @@ version = "0.1.3" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "ff537e5a3cba92fb48f30fec46723510450f2c0e" +git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.10.2" +version = "0.11.1" [[FixedPointNumbers]] deps = ["Statistics"] @@ -146,17 +146,17 @@ version = "0.8.4" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "8de2519a83c6c1c2442c2f481dd9a8364855daf4" +git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.14" +version = "0.10.16" [[Functors]] deps = ["MacroTools"] -git-tree-sha1 = "5b082c8f4507e15d435f7727efc9d7006f954140" +git-tree-sha1 = "60dc8972fec14145524caf17edfee222ab531e37" repo-rev = "dg/grad" repo-url = "https://github.com/FluxML/Functors.jl.git" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.1.0" +version = "0.2.0" [[GPUArrays]] deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] @@ -193,9 +193,9 @@ version = "0.8.4" [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "d0d99629d6ae4a3e211ae83d8870907bd842c811" +git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "3.5.2" +version = "3.6.0" [[LibGit2]] deps = ["Printf"] @@ -229,23 +229,18 @@ version = "0.5.0" [[Missings]] deps = ["DataAPI"] -git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8" +git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.4" +version = "0.4.5" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" -[[MuladdMacro]] -git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.2" - [[NNlib]] -deps = ["ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "13fd29731c7f609cb82a3a544c5538584d22c153" +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "df42d0816edfc24f5b82a728f46381613c4dff79" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.11" +version = "0.7.14" [[NaNMath]] git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" @@ -260,16 +255,16 @@ version = "0.5.3+4" [[Optimisers]] deps = ["Functors", "Random", "Requires", "Statistics"] -git-tree-sha1 = "3130cebce66aed0943bdd8f3d61c641e44d7ea0e" +git-tree-sha1 = "0a9f8b051708c9a41ead0810cf6aff9af2edb8bc" repo-rev = "master" repo-url = "https://github.com/FluxML/Optimisers.jl.git" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" version = "0.1.0" [[OrderedCollections]] -git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db" +git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.2" +version = "1.3.3" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -292,10 +287,9 @@ deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[Reexport]] -deps = ["Pkg"] -git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5" uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "0.2.0" +version = "1.0.0" [[Requires]] deps = ["UUIDs"] @@ -391,9 +385,9 @@ version = "1.2.11+18" [[Zygote]] deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "d88a7e71fc2eef9510187b1c7d4af7a5177633d0" +git-tree-sha1 = "52835a83f7c899cfcb95f796d584201812887ea8" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.0" +version = "0.6.3" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/Project.toml b/Project.toml index e313e4e40f..8e5b56316f 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ Adapt = "2.0, 3.0" CUDA = "2.1" CodecZlib = "0.7" Colors = "0.12" -Functors = "0.1" +Functors = "0.1, 0.2" Juno = "0.8" MacroTools = "0.5" NNlib = "0.7.10" From 857f4dbc31bb0e9a51383ea5796402c68f79aa33 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 9 Feb 2021 20:32:47 +0530 Subject: [PATCH 16/25] rm import --- src/optimise/Optimise.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 16404a6507..6594d7be84 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -8,7 +8,6 @@ export train!, InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, ClipValue, ClipNorm -include("optimisers.jl") include("train.jl") end From 06c659966f5880daf1a68532b14ff959164fd805 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 9 Feb 2021 20:40:58 +0530 Subject: [PATCH 17/25] fix exports --- src/optimise/Optimise.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 6594d7be84..1e503378fb 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -3,10 +3,10 @@ module Optimise using LinearAlgebra using Optimisers -export train!, +export train!, skip, stop + Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, - InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser, - ClipValue, ClipNorm + InvDecay, ExpDecay, WeightDecay, stop, skip, ChainOptimiser include("train.jl") From 3012f877edbcae79160ed3749b420ddecd2a71a5 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Tue, 9 Feb 2021 20:56:42 +0530 Subject: [PATCH 18/25] updates to tests --- src/optimise/Optimise.jl | 5 +++-- test/optimise.jl | 30 +++++++++++++++--------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 1e503378fb..c0d56dc20e 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -2,11 +2,12 @@ module Optimise using LinearAlgebra using Optimisers +using Optimisers: apply -export train!, skip, stop +export train!, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief, - InvDecay, ExpDecay, WeightDecay, stop, skip, ChainOptimiser + WeightDecay, stop, skip, ChainOptimiser include("train.jl") diff --git a/test/optimise.jl b/test/optimise.jl index 40db490f57..4069fb4814 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -9,21 +9,21 @@ using Random # so that w and w' are different Random.seed!(84) w = randn(10, 10) - # @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(), - # NADAM(), RADAM(), Descent(), ADAM(), OADAM(), AdaBelief(), - # Nesterov(), RMSProp(), Momentum()] - # Random.seed!(42) - # w′ = randn(10, 10) - # b = Flux.Zeros() - # loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) - # for t = 1: 10^5 - # θ = params([w′, b]) - # x = rand(10) - # θ̄ = gradient(() -> loss(x), θ) - # Optimise.update!(opt, θ, θ̄) - # end - # @test loss(rand(10, 10)) < 0.01 - # end + @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(), AMSGrad(), + NADAM(), RADAM(), Descent(), ADAM(), OADAM(), AdaBelief(), + Nesterov(), RMSProp(), Momentum()] + Random.seed!(42) + w′ = randn(10, 10) + b = Flux.Zeros() + loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) + for t = 1: 10^5 + θ = params([w′, b]) + x = rand(10) + θ̄ = gradient(() -> loss(x), θ) + Flux.Optimise.update!(opt, θ, θ̄) + end + @test loss(rand(10, 10)) < 0.01 + end Random.seed!(84) w′ = rand(3,3) From 952f65c0f1fdc863a7edaf9e1b10300f4654e830 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 10 Feb 2021 15:58:30 +0530 Subject: [PATCH 19/25] dirty name hack --- test/optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index 4069fb4814..d31ad70e13 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -30,7 +30,7 @@ using Random @testset for o in (Descent(0.1), Momentum(0.01, 0.9), Nesterov(0.001, 0.9), RMSProp(0.001, 0.9), ADAM(0.001, (0.9, 0.99))) w = rand(3,3) - st = Optimisers.init(o,w) + st = Flux.Optimise.Optimisers.init(o,w) loss(x, y) = mean((x .- y) .^ 2) l = loss(w, w′) for i = 1:10^4 From 6185e2e3bd30bd04f2f6471a255fc99ef6547d28 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 10 Feb 2021 18:33:28 +0530 Subject: [PATCH 20/25] add return state from update step --- src/optimise/train.jl | 7 +++++-- test/optimise.jl | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index cf7aa7300c..ffcb3562d6 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -26,10 +26,13 @@ function update!(opt, x, x̄, st) end function update!(opt, xs::Params, gs, st) - for x in xs + st_ = [] + for (x,s) in zip(xs,st) gs[x] == nothing && continue - _, st = update!(opt, x, gs[x], st) + _, s = update!(opt, x, gs[x], s) + append!(st_, s) end + nothing, st_ end # Callback niceties diff --git a/test/optimise.jl b/test/optimise.jl index d31ad70e13..278c1b898f 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -16,11 +16,12 @@ using Random w′ = randn(10, 10) b = Flux.Zeros() loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) + st = [Flux.Optimisers.init(opt, p) for p in [w′, b]] for t = 1: 10^5 θ = params([w′, b]) x = rand(10) θ̄ = gradient(() -> loss(x), θ) - Flux.Optimise.update!(opt, θ, θ̄) + _, st = Flux.Optimise.update!(opt, θ, θ̄, st) end @test loss(rand(10, 10)) < 0.01 end From a56d2d30a73678fa735a4e6dffca2814fa7fbf06 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 10 Feb 2021 18:35:34 +0530 Subject: [PATCH 21/25] disable some currently unsupported cases --- test/optimise.jl | 112 +++++++++++++++++++++++------------------------ 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/test/optimise.jl b/test/optimise.jl index 278c1b898f..0da580ccde 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -102,59 +102,59 @@ end Flux.train!(loss, Flux.params(r), (r,), Descent()) end -@testset "ExpDecay" begin - - @testset "Sanity Check" begin - o = ExpDecay(0.2, 0.5, 1, 1e-3) - p = [0.0] - steps = 1:8 - eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip) - eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] - @test eta_actual == eta_expected - end - - w = randn(10, 10) - o = ExpDecay(0.1, 0.1, 1000, 1e-4) - w1 = randn(10,10) - loss(x) = Flux.Losses.mse(w*x, w1*x) - flag = 1 - decay_steps = [] - for t = 1:10^5 - prev_eta = o.eta - θ = Params([w1]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - prev_grad = collect(θ̄[w1]) - delta = Optimise.apply!(o, w1, θ̄[w1]) - w1 .-= delta - new_eta = o.eta - if new_eta != prev_eta - push!(decay_steps, t) - end - array = fill(o.eta, size(prev_grad)) - if array .* prev_grad != delta - flag = 0 - end - end - @test flag == 1 - # Test to check if decay happens at decay steps. Eta reaches - # clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). - ground_truth = [] - for i in 1:4 - push!(ground_truth, 1000*i) # Expected decay steps for this example. - end - @test decay_steps == ground_truth - @test o.eta == o.clip -end - -@testset "Clipping" begin - w = randn(10, 10) - loss(x) = sum(w * x) - θ = Params([w]) - x = 1000 * randn(10) - w̄ = gradient(() -> loss(x), θ)[w] - w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) - @test all(w̄_value .<= 1) - w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) - @test norm(w̄_norm) <= 1 -end +# @testset "ExpDecay" begin +# +# @testset "Sanity Check" begin +# o = ExpDecay(0.2, 0.5, 1, 1e-3) +# p = [0.0] +# steps = 1:8 +# eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip) +# eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] +# @test eta_actual == eta_expected +# end +# +# w = randn(10, 10) +# o = ExpDecay(0.1, 0.1, 1000, 1e-4) +# w1 = randn(10,10) +# loss(x) = Flux.Losses.mse(w*x, w1*x) +# flag = 1 +# decay_steps = [] +# for t = 1:10^5 +# prev_eta = o.eta +# θ = Params([w1]) +# x = rand(10) +# θ̄ = gradient(() -> loss(x), θ) +# prev_grad = collect(θ̄[w1]) +# delta = Optimise.apply!(o, w1, θ̄[w1]) +# w1 .-= delta +# new_eta = o.eta +# if new_eta != prev_eta +# push!(decay_steps, t) +# end +# array = fill(o.eta, size(prev_grad)) +# if array .* prev_grad != delta +# flag = 0 +# end +# end +# @test flag == 1 +# # Test to check if decay happens at decay steps. Eta reaches +# # clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). +# ground_truth = [] +# for i in 1:4 +# push!(ground_truth, 1000*i) # Expected decay steps for this example. +# end +# @test decay_steps == ground_truth +# @test o.eta == o.clip +# end +# +# @testset "Clipping" begin +# w = randn(10, 10) +# loss(x) = sum(w * x) +# θ = Params([w]) +# x = 1000 * randn(10) +# w̄ = gradient(() -> loss(x), θ)[w] +# w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) +# @test all(w̄_value .<= 1) +# w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) +# @test norm(w̄_norm) <= 1 +# end From 447b8a4a7f5c52b35d01a221ed9b84ddb103310a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Feb 2021 19:34:10 +0530 Subject: [PATCH 22/25] add Optimisers to imports --- test/optimise.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/optimise.jl b/test/optimise.jl index 0da580ccde..57daac2a3a 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -1,6 +1,7 @@ using Flux.Optimise using Flux.Optimise: runall using Flux: Params, gradient +using Flux.Optimise: Optimisers using Test using Random From a41c6e1372d921d0a916bded80d30c80a06bb26d Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Feb 2021 19:55:13 +0530 Subject: [PATCH 23/25] qualify optimisers --- test/optimise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimise.jl b/test/optimise.jl index 57daac2a3a..3041009a4e 100644 --- a/test/optimise.jl +++ b/test/optimise.jl @@ -17,7 +17,7 @@ using Random w′ = randn(10, 10) b = Flux.Zeros() loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) - st = [Flux.Optimisers.init(opt, p) for p in [w′, b]] + st = [Optimisers.init(opt, p) for p in [w′, b]] for t = 1: 10^5 θ = params([w′, b]) x = rand(10) From da31efd712230374d0c56b975acafea34fd75f3b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 24 Feb 2021 20:08:24 +0530 Subject: [PATCH 24/25] define init(o, params) --- src/optimise/train.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index ffcb3562d6..5d148045f1 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -79,6 +79,8 @@ function stop() throw(StopException()) end +Optimisers.init(o, ps::Params) = [Optimisers.init(o,p) for p in ps] + batchmemaybe(x) = tuple(x) batchmemaybe(x::Tuple) = x @@ -102,7 +104,7 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. function train!(loss, ps, data, opt; cb = () -> ()) ps = Params(ps) cb = runall(cb) - st = [Optimisers.init(opt, p) for p in ps] + st = Optimisers.init(opt, ps) @progress for d in data try gs = gradient(ps) do From ff17b0b6d9ed41b61b582ab0d8cfedaf588daba3 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 16 Jan 2022 21:37:58 +0530 Subject: [PATCH 25/25] Apply suggestions from code review Co-authored-by: Kyle Daruwalla --- docs/src/training/optimisers.md | 2 +- src/optimise/train.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 05947da26a..21fe07e2cd 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -98,7 +98,7 @@ v = ρ * v - η * Δ w = w - v ``` -The `apply` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and returns the new state of the optimizer. +The `apply` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients and optimizer state. Flux internally calls on this function via the `update!` function. It shares the API with `apply` but ensures that multiple parameters are handled gracefully. diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 5d148045f1..9db9e30620 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -32,7 +32,7 @@ function update!(opt, xs::Params, gs, st) _, s = update!(opt, x, gs[x], s) append!(st_, s) end - nothing, st_ + xs, st_ end # Callback niceties