Skip to content

Commit

Permalink
Add save_freq option to CAPQL train method (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre authored Oct 19, 2024
1 parent 76147d3 commit 8d2f89b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions morl_baselines/multi_policy/capql/capql.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def train(
eval_freq: int = 10000,
reset_num_timesteps: bool = False,
checkpoints: bool = False,
save_freq: int = 10000,
):
"""Train the agent.
Expand All @@ -403,6 +404,7 @@ def train(
eval_freq (int): Number of timesteps between evaluations during an iteration.
reset_num_timesteps (bool): Whether to reset the number of timesteps.
checkpoints (bool): Whether to save checkpoints.
save_freq (int): Number of timesteps between checkpoints.
"""
if self.log:
self.register_additional_config(
Expand Down Expand Up @@ -476,7 +478,7 @@ def train(
)

# Checkpoint
if checkpoints:
self.save(filename="CAPQL", save_replay_buffer=False)
if checkpoints and self.global_step % save_freq == 0:
self.save(filename=f"CAPQL step={self.global_step}", save_replay_buffer=False)

self.close_wandb()

0 comments on commit 8d2f89b

Please # to comment.