diff --git a/src/fit.jl b/src/fit.jl index 1a939d6..1cad485 100644 --- a/src/fit.jl +++ b/src/fit.jl @@ -453,7 +453,7 @@ end function train!(ts, nepoch::Number=-1, args...; trainer=trainepoch_EM2!, patient::Number=trainer(:patient), - optimiser=Momentum(η=0.3, ρ=0.5), scheduler=opt->eta!(opt, eta(opt)*0.8), + optimiser=Momentum(η=1/4, ρ=0.5), scheduler=lr->lr*√0.5, callback=x -> x, reposition=i -> true, resource=trainer(inputs=ts), kargs...) reposition_flag = true if reposition isa Function @@ -477,6 +477,7 @@ function train!(ts, nepoch::Number=-1, args...; indi_r = MonotoneIndicator{Int}() #for reposition indi_g = MonotoneIndicator{Int}() #for global patient indi_s = MonotoneIndicator{Int}() #for lr scheduler + eta_list = [] reposition_count = 0. last_repositioned = nothing colist = nothing @@ -524,9 +525,17 @@ function train!(ts, nepoch::Number=-1, args...; @info "The repositioning strategy failed after $ep epochs" break end - if indi_s.age > max(1, patient, nepoch / 200 * max(1, (length(ts) / indi_g.min))) - scheduler(optimiser) - @info "@epoch $ep(+$(indi_s.age)) η -> $(round(eta(optimiser), digits=3)) (current $nc collisions, best $(indi_s.min) collisions)" + if indi_s.age > max(1, patient, nepoch / 50) + if isempty(eta_list) || indi_s.min < eta_list[end][2] || (indi_s.min == eta_list[end][2] && rand()>0.5) + _eta = eta(optimiser) + push!(eta_list, (_eta, indi_s.min)) + eta!(optimiser, scheduler(_eta)) + @info "@epoch $ep(+$(indi_s.age)) η -> $(round(eta(optimiser), digits=3)) (current $nc collisions, best $(indi_s.min) collisions)" + else + last_eta, last_nc = pop!(eta_list) + eta!(optimiser, last_eta) + @info "@epoch $ep(+$(indi_s.age)) η <- $(round(eta(optimiser), digits=3)) (collisions $(indi_s.min) ≥ $(last_nc))" + end reset!(indi_s) end if indi_g.age > max(2, 2patient, nepoch / 50 * max(1, (length(ts) / indi_g.min)))