Skip to content

Commit

Permalink
Add run tests for simba
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Dec 20, 2024
1 parent 5d1342f commit bfd2531
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,26 @@ def test_tqc(tmp_path) -> None:
check_save_load(model, TQC, tmp_path)


@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, CrossQ, "SimbaSAC"])
@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, CrossQ, "SimbaSAC", "SimbaCrossQ"])
def test_sac_td3(tmp_path, model_class) -> None:
policy = "MlpPolicy"
net_kwargs = {}
if model_class == "SimbaSAC":
model_class = SAC
policy = "SimbaPolicy"
net_kwargs = dict(net_arch=[64])
elif model_class == "SimbaCrossQ":
model_class = CrossQ
policy = "SimbaPolicy"
net_kwargs = dict(net_arch=[64])

model = model_class(
policy,
"Pendulum-v1",
verbose=1,
gradient_steps=1,
learning_rate=1e-3,
policy_kwargs=net_kwargs,
)
key_before_learn = model.key
model.learn(110)
Expand Down

0 comments on commit bfd2531

Please # to comment.