diff --git a/Project.toml b/Project.toml index a6d6f83..14ea8ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Stuffing" uuid = "4175e07e-e5b7-423e-8796-3ea7f6d48281" authors = ["guoyongzhi "] -version = "0.4.0" +version = "0.4.1" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/train.jl b/src/train.jl index 5166693..1ee0815 100644 --- a/src/train.jl +++ b/src/train.jl @@ -469,7 +469,7 @@ end function train!(ts, nepoch::Number=-1, args...; trainer=trainepoch_EM2!, patient::Number=trainer(:patient), optimiser=Momentum(η=1/4, ρ=0.5), - callbackstep=1, callbackfun=x->x, teleporting=i->true, kargs...) + callbackstep=1, callbackfun=x->x, teleporting=i->true, resource=trainer(inputs=ts), kargs...) teleporton = true if teleporting isa Function on = teleporting @@ -495,7 +495,6 @@ function train!(ts, nepoch::Number=-1, args...; nc_min_g = typemax(Int) teleport_count = 0. last_cinds = nothing - resource = trainer(inputs=ts) collpool = nothing if :collpool in keys(resource) collpool = resource[:collpool]