From 8d2f89b4573c4ea7780855b631a5625117b481ad Mon Sep 17 00:00:00 2001 From: Lucas Alegre Date: Sat, 19 Oct 2024 12:37:16 -0300 Subject: [PATCH] Add save_freq option to CAPQL train method (#122) --- morl_baselines/multi_policy/capql/capql.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/morl_baselines/multi_policy/capql/capql.py b/morl_baselines/multi_policy/capql/capql.py index 1ae46bd..71f1b08 100644 --- a/morl_baselines/multi_policy/capql/capql.py +++ b/morl_baselines/multi_policy/capql/capql.py @@ -389,6 +389,7 @@ def train( eval_freq: int = 10000, reset_num_timesteps: bool = False, checkpoints: bool = False, + save_freq: int = 10000, ): """Train the agent. @@ -403,6 +404,7 @@ def train( eval_freq (int): Number of timesteps between evaluations during an iteration. reset_num_timesteps (bool): Whether to reset the number of timesteps. checkpoints (bool): Whether to save checkpoints. + save_freq (int): Number of timesteps between checkpoints. """ if self.log: self.register_additional_config( @@ -476,7 +478,7 @@ def train( ) # Checkpoint - if checkpoints: - self.save(filename="CAPQL", save_replay_buffer=False) + if checkpoints and self.global_step % save_freq == 0: + self.save(filename=f"CAPQL step={self.global_step}", save_replay_buffer=False) self.close_wandb()