diff --git a/src/TBLogger.jl b/src/TBLogger.jl index a3cf9cd..cee5087 100644 --- a/src/TBLogger.jl +++ b/src/TBLogger.jl @@ -5,6 +5,16 @@ mutable struct TBLogger{P,S} <: AbstractLogger global_step::Int step_increment::Int min_level::LogLevel + + function TBLogger{P,S}(logdir::P, + file::S, + all_files::Dict{String, S}, + global_step::Int, + step_increment::Int, + min_level::LogLevel) where {P,S} + lg = new{P, S}(logdir, file, all_files, global_step, step_increment, min_level) + return Base.finalizer(Base.close, lg) + end end @@ -195,6 +205,18 @@ Returns the internal step counter of the logger. """ step(lg::TBLogger) = lg.global_step +""" + close(lg) + +Close the TBLogger `lg`, releasing all file handles. +""" +function Base.close(lg::TBLogger) + # close open streams + for k=keys(lg.all_files) + close(lg.all_files[k]) + end +end + """ reset!(lg) diff --git a/test/test_TBLogger.jl b/test/test_TBLogger.jl index 67f068d..9992dae 100644 --- a/test/test_TBLogger.jl +++ b/test/test_TBLogger.jl @@ -89,6 +89,31 @@ end close.(values(tbl.all_files)) end +@testset "closing" begin + tbl = TBLogger(test_log_dir*"run", tb_overwrite) + TensorBoardLogger.add_eventfile(tbl, "pp") + files = keys(tbl.all_files) + + close(tbl) + @test begin + foreach(f -> rm(joinpath(test_log_dir*"run", f)), files) + # rm will error if the file is still open + true + end + + tbl = TBLogger(test_log_dir*"run", tb_overwrite) + TensorBoardLogger.add_eventfile(tbl, "pp") + files = keys(tbl.all_files) + + tbl = nothing + Base.finalize(tbl) + @test begin + foreach(f -> rm(joinpath(test_log_dir*"run", f)), files) + # rm will error if the file is still open + true + end +end + @testset "resetting" begin tbl = TBLogger(test_log_dir*"run", tb_overwrite) TensorBoardLogger.add_eventfile(tbl, "pp")