Use Optimisers.jl #1481

Open · wants to merge 25 commits into master
80 changes: 42 additions & 38 deletions Manifest.toml
@@ -14,9 +14,9 @@ version = "0.3.3"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "27edd95a09fd428113ca019c092e8aeca2eb1f2d"
git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.0.0"
version = "3.2.0"

[[Artifacts]]
deps = ["Pkg"]
@@ -40,21 +40,21 @@ version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
git-tree-sha1 = "39f6f584bec264ace76f924d1c8637c85617697e"
git-tree-sha1 = "6ccc73b2d8b671f7a65c92b5f08f81422ebb7547"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "2.4.0"
version = "2.4.1"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "31b28f5123afa5e5ca0c885e4051896032754578"
git-tree-sha1 = "56bbb956a573ac16b277008edb1762ef80076e78"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.45"
version = "0.7.50"

[[ChainRulesCore]]
deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"]
git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a"
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "d3d0a4e0d5bc03a6c97f4d249c8a471fc20a2f33"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.24"
version = "0.9.28"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
@@ -93,15 +93,15 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.4+0"

[[DataAPI]]
git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84"
git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.4.0"
version = "1.5.1"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "fb0aa371da91c1ff9dc7fbed6122d3e411420b9c"
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.8"
version = "0.18.9"

[[Dates]]
deps = ["Printf"]
@@ -134,9 +134,9 @@ version = "0.1.3"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "ff537e5a3cba92fb48f30fec46723510450f2c0e"
git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.10.2"
version = "0.11.1"

[[FixedPointNumbers]]
deps = ["Statistics"]
@@ -146,15 +146,17 @@ version = "0.8.4"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "8de2519a83c6c1c2442c2f481dd9a8364855daf4"
git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.14"
version = "0.10.16"

[[Functors]]
deps = ["MacroTools"]
git-tree-sha1 = "f40adc6422f548176bb4351ebd29e4abf773040a"
git-tree-sha1 = "60dc8972fec14145524caf17edfee222ab531e37"
repo-rev = "dg/grad"
repo-url = "https://github.com/FluxML/Functors.jl.git"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.1.0"
version = "0.2.0"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
@@ -191,9 +193,9 @@ version = "0.8.4"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "d0d99629d6ae4a3e211ae83d8870907bd842c811"
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.5.2"
version = "3.6.0"

[[LibGit2]]
deps = ["Printf"]
@@ -227,23 +229,18 @@ version = "0.5.0"

[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8"
git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.4"
version = "0.4.5"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

-[[MuladdMacro]]
-git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68"
-uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
-version = "0.2.2"

[[NNlib]]
deps = ["ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "13fd29731c7f609cb82a3a544c5538584d22c153"
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "df42d0816edfc24f5b82a728f46381613c4dff79"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.11"
version = "0.7.14"

[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
@@ -256,10 +253,18 @@ git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"

+[[Optimisers]]
+deps = ["Functors", "Random", "Requires", "Statistics"]
+git-tree-sha1 = "0a9f8b051708c9a41ead0810cf6aff9af2edb8bc"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/Optimisers.jl.git"
+uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
+version = "0.1.0"

[[OrderedCollections]]
git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db"
git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.3.2"
version = "1.3.3"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -282,10 +287,9 @@ deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Reexport]]
deps = ["Pkg"]
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "0.2.0"
version = "1.0.0"

[[Requires]]
deps = ["UUIDs"]
@@ -381,9 +385,9 @@ version = "1.2.11+18"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "d88a7e71fc2eef9510187b1c7d4af7a5177633d0"
git-tree-sha1 = "52835a83f7c899cfcb95f796d584201812887ea8"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.0"
version = "0.6.3"

[[ZygoteRules]]
deps = ["MacroTools"]
3 changes: 2 additions & 1 deletion Project.toml
@@ -14,6 +14,7 @@ Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -31,7 +32,7 @@ Adapt = "2.0, 3.0"
CUDA = "2.1"
CodecZlib = "0.7"
Colors = "0.12"
Functors = "0.1"
Functors = "0.1, 0.2"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.10"
26 changes: 14 additions & 12 deletions docs/src/training/optimisers.md
@@ -35,7 +35,7 @@ Running this will alter the parameters `W` and `b` and our loss should go down.
opt = Descent(0.1) # Gradient descent with learning rate 0.1

for p in (W, b)
-  update!(opt, p, grads[p])
+  opt(p, grads[p])
end
```
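
For reference, `Descent(0.1)` is plain gradient descent; the loop above amounts to the following self-contained sketch (the data and loss here are illustrative, not part of this PR):

```julia
using Flux

W, b = rand(2, 5), rand(2)
x, y = rand(5), rand(2)
loss() = Flux.Losses.mse(W * x .+ b, y)
grads = gradient(loss, Flux.Params([W, b]))

for p in (W, b)
    p .-= 0.1 .* grads[p]  # what opt(p, grads[p]) does for Descent(0.1)
end
```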

@@ -70,23 +70,24 @@ Flux's optimisers are built around a `struct` that holds all the optimiser param
In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work through this with a simple example.

```julia
-mutable struct Momentum
+struct Momentum
  eta
  rho
-  velocity
end

-Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict())
+Momentum(eta::Real, rho::Real) = Momentum(eta, rho)
+Optimisers.init(opt::Momentum, x) = (zero(x),)
```

The `Momentum` type will act as our optimiser in this case. Notice that the hyperparameters are stored as fields, while the velocity now lives in the state created by `Optimisers.init`; each parameter in our model gets its own state entry. We can now define the rule applied when this optimiser is invoked.

```julia
-function Flux.Optimise.apply!(o::Momentum, x, Δ)
+function Flux.Optimise.apply(o::Momentum, x, Δ, st)
  η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
-  @. v = ρ * v - η * Δ
-  @. Δ = -v
+  v, = st
+  v = @. ρ * v - η * Δ
+  Δ = @. -v
+  Δ, (v,)
end
```
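
To see the new stateless rule in action on a single array, here is a usage sketch (it assumes the `Momentum` definitions above have been evaluated; the variable names are illustrative):

```julia
o  = Momentum(0.1, 0.9)
x  = rand(3)                 # a parameter array
Δx = rand(3)                 # its gradient
st = Optimisers.init(o, x)   # fresh state, here (zero(x),)

step, st = Flux.Optimise.apply(o, x, Δx, st)
x .-= step                   # take the computed update
```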

@@ -97,9 +98,9 @@
w = w - v
```

-The `apply!` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and subsequently updates the state of the optimiser.
+`apply` defines the update rule for an optimiser `opt`, given the parameters, gradients, and current state. It returns the updated gradients and the new optimiser state.

-Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully.
+Flux internally calls this function via `update!`, which shares the API with `apply` but ensures that multiple parameters are handled gracefully.
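
As a rough illustration of that contract (a hypothetical sketch, not this PR's actual implementation), `update!` might thread one state entry per parameter through `apply`:

```julia
# Hypothetical: `st` is assumed to hold one state tuple per parameter.
function update!(opt, xs::Params, gs, st)
    for (i, x) in enumerate(xs)
        Δ, st[i] = apply(opt, x, gs[x], st[i])  # rule computes step and new state
        x .-= Δ                                 # apply the step in place
    end
    xs, st
end
```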

## Composing Optimisers

@@ -121,11 +122,12 @@ ps = Params([w, w1])
loss(x) = Flux.Losses.mse(w * x, w1 * x)

loss(rand(10)) # around 9
+st = Optimisers.init(opt, [w, w1])

for t = 1:10^5
-  θ = Params([w, w1])
+  θ = ps
  θ̄ = gradient(() -> loss(rand(10)), θ)
-  Flux.Optimise.update!(opt, θ, θ̄)
+  ps, st = Flux.Optimise.update!(opt, θ, θ̄, st)
end

loss(rand(10)) # around 0.9
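
With explicit state, composed rules follow the same pattern. A hypothetical sketch using the `ChainOptimiser` exported by this PR (the constructor shown is an assumption; this diff does not confirm its exact signature):

```julia
# Assumed composition API: each chained rule's `apply` runs in sequence.
opt = ChainOptimiser(Momentum(0.01, 0.9), WeightDecay(1e-4))
st  = Optimisers.init(opt, [w, w1])  # one state entry per parameter, as above
```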
12 changes: 6 additions & 6 deletions src/optimise/Optimise.jl
@@ -1,14 +1,14 @@
module Optimise

using LinearAlgebra
+using Optimisers
+using Optimisers: apply

-export train!, update!,
-  Descent, ADAM, Momentum, Nesterov, RMSProp,
-  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief,
-  InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
-  ClipValue, ClipNorm
+export train!,
+  Descent, ADAM, Momentum, Nesterov, RMSProp,
+  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief,
+  WeightDecay, stop, skip, ChainOptimiser

include("optimisers.jl")
include("train.jl")

end