Skip to content

Commit 638ace1

Browse files
Merge pull request #9 from darsnack/darsnack/initial-impl
Move all optimizers to Optimisers.jl
2 parents 051270b + ce489fb commit 638ace1

File tree

4 files changed

+501
-77
lines changed

4 files changed

+501
-77
lines changed

src/Optimisers.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using Functors: functor, fmap, isleaf
55
include("interface.jl")
66
include("rules.jl")
77

8-
export Descent, Momentum, Nesterov, RMSProp,
9-
ADAM
8+
export Descent, ADAM, Momentum, Nesterov, RMSProp,
9+
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
10+
WeightDecay, OptimiserChain
1011

1112
end # module

src/interface.jl

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
function patch(x, x̄)
2-
return x .-
3-
end
1+
patch(x, x̄) = x .-
42

53
function state(o, x)
64
if isleaf(x)
75
return init(o, x)
86
else
97
x, _ = functor(x)
10-
map(x -> state(o, x), x)
8+
return map(x -> state(o, x), x)
119
end
1210
end
1311

@@ -20,11 +18,11 @@ function update(o, x::T, x̄, state) where T
2018
if=== nothing
2119
return x, state
2220
elseif isleaf(x)
23-
_update(o, x, x̄, state)
21+
return _update(o, x, x̄, state)
2422
else
2523
x̄, _ = functor(typeof(x), x̄)
26-
x, re = functor(typeof(x), x)
24+
x, restructure = functor(typeof(x), x)
2725
xstate = map((x, x̄, state) -> update(o, x, x̄, state), x, x̄, state)
28-
re(map(first, xstate)), map(x -> x[2], xstate)
26+
return restructure(map(first, xstate)), map(x -> x[2], xstate)
2927
end
3028
end

0 commit comments

Comments
 (0)