From 99ae33d8e1d022cde1a7a30d6568838e25d6847a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 27 Nov 2022 12:02:44 -0500 Subject: [PATCH 1/2] write => for OptimiserChain --- src/rules.jl | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index ecc58609..0cd7ccb1 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -636,10 +636,12 @@ function _norm(dx::Broadcast.Broadcasted, p::Real) end """ - OptimiserChain(opts...) + OptimiserChain(o1, o2, o34...) + o1 => o2 => o3 -Compose a sequence of optimisers so that each `opt` in `opts` +Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)` updates the gradient, in the order specified. +May be entered using `Pair` syntax with several `AbstractRule`s. With an empty sequence, `OptimiserChain()` is the identity, so `update!` will subtract the full gradient from the parameters. @@ -648,12 +650,13 @@ This is equivalent to `Descent(1)`. # Example ```jldoctest -julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1)); +julia> o = ClipGrad(1.0) => Descent(0.1) +OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)) julia> m = (zeros(3),); julia> s = Optimisers.setup(o, m) -(Leaf(OptimiserChain(ClipGrad(1.0), Descent(0.1)), (nothing, nothing)),) +(Leaf(ClipGrad(1.0) => Descent(0.1), (nothing, nothing)),) julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting ([-0.03, -0.1, -0.1],) @@ -664,6 +667,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule end OptimiserChain(opts...) = OptimiserChain(opts) +Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b) +Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...) + @functor OptimiserChain init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts) @@ -679,7 +685,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...) end end -function Base.show(io::IO, c::OptimiserChain) +function Base.show(io::IO, c::OptimiserChain) # compact show + if length(c.opts) > 1 + join(io, c.opts, " => ") + else + show(io, MIME"text/plain"(), c) + end +end +function Base.show(io::IO, ::MIME"text/plain", c::OptimiserChain) print(io, "OptimiserChain(") join(io, c.opts, ", ") print(io, ")") From 4a559755c461aa9c7df98414297c7ac033e7bcea Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:53:32 -0400 Subject: [PATCH 2/2] change from => to >> --- src/rules.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 0cd7ccb1..1cda82d9 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -637,11 +637,11 @@ end """ OptimiserChain(o1, o2, o34...) - o1 => o2 => o3 + o1 >> o2 >> o3 Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)` updates the gradient, in the order specified. -May be entered using `Pair` syntax with several `AbstractRule`s. +May be entered using the `>>` operator with several `AbstractRule`s. With an empty sequence, `OptimiserChain()` is the identity, so `update!` will subtract the full gradient from the parameters. @@ -650,8 +650,8 @@ This is equivalent to `Descent(1)`. # Example ```jldoctest -julia> o = ClipGrad(1.0) => Descent(0.1) -OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)) +julia> o = ClipGrad(1.0) >> Descent(0.1) +OptimiserChain(ClipGrad(1.0), Descent(0.1)) julia> m = (zeros(3),); @@ -667,8 +667,10 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule end OptimiserChain(opts...) = OptimiserChain(opts) -Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b) -Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...) +@doc @doc(OptimiserChain) +Base.:(>>)(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b) +Base.:(>>)(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...) +Base.:(>>)(ab::OptimiserChain, c::AbstractRule) = OptimiserChain(ab.opts..., c) @functor OptimiserChain @@ -687,7 +689,7 @@ end function Base.show(io::IO, c::OptimiserChain) # compact show if length(c.opts) > 1 - join(io, c.opts, " => ") + join(io, c.opts, " >> ") else show(io, MIME"text/plain"(), c) end