Update PPO to support net_arch, and additional fixes (#65)
* Add support for flexible arch in PPO

* Fix ent_coef logging for TQC

* Fix name order

* Fix ent_coef logging for SAC

* Hotfix for PPO, do not squash output at test time

* Fix typo

* Fix typo in common policy

* Try Gaussian dist for TQC

* Revert "Try Gaussian dist for TQC"

This reverts commit 6eeaf23.

* Fix CrossQ ent_coef logging

* Log PPO std when possible

* Fix for CrossQ
araffin authored Feb 14, 2025
1 parent 9cad1d0 commit 8238fcc
Showing 10 changed files with 74 additions and 42 deletions.
8 changes: 4 additions & 4 deletions sbx/common/policies.py
@@ -37,15 +37,15 @@ def __init__(self, *args, **kwargs):

     @staticmethod
     @jax.jit
-    def sample_action(actor_state, obervations, key):
-        dist = actor_state.apply_fn(actor_state.params, obervations)
+    def sample_action(actor_state, observations, key):
+        dist = actor_state.apply_fn(actor_state.params, observations)
         action = dist.sample(seed=key)
         return action

     @staticmethod
     @jax.jit
-    def select_action(actor_state, obervations):
-        return actor_state.apply_fn(actor_state.params, obervations).mode()
+    def select_action(actor_state, observations):
+        return actor_state.apply_fn(actor_state.params, observations).mode()

     @no_type_check
     def predict(
18 changes: 15 additions & 3 deletions sbx/crossq/crossq.py
@@ -219,7 +219,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.policy.actor_state,
             self.ent_coef_state,
             self.key,
-            (actor_loss_value, qf_loss_value, ent_coef_value),
+            (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value),
         ) = self._train(
             self.gamma,
             self.target_entropy,
@@ -236,6 +236,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/actor_loss", actor_loss_value.item())
         self.logger.record("train/critic_loss", qf_loss_value.item())
+        self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item())
         self.logger.record("train/ent_coef", ent_coef_value.item())

     @staticmethod
@@ -421,6 +422,7 @@ def _train(
                 "actor_loss": jnp.array(0.0),
                 "qf_loss": jnp.array(0.0),
                 "ent_coef_loss": jnp.array(0.0),
+                "ent_coef_value": jnp.array(0.0),
             },
         }

@@ -468,7 +470,12 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
                 target_entropy,
                 key,
             )
-            info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value}
+            info = {
+                "actor_loss": actor_loss_value,
+                "qf_loss": qf_loss_value,
+                "ent_coef_loss": ent_coef_loss_value,
+                "ent_coef_value": ent_coef_value,
+            }

             return {
                 "actor_state": actor_state,
@@ -485,5 +492,10 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
             update_carry["actor_state"],
             update_carry["ent_coef_state"],
             update_carry["key"],
-            (update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
+            (
+                update_carry["info"]["actor_loss"],
+                update_carry["info"]["qf_loss"],
+                update_carry["info"]["ent_coef_loss"],
+                update_carry["info"]["ent_coef_value"],
+            ),
         )
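Note on the change above: `_train` accumulates its metrics in a carry dictionary that `jax.lax.fori_loop` threads through every gradient step, so exposing `ent_coef_value` means adding it both to the initial carry and to the per-step `info` dict. A minimal, self-contained sketch of that carry pattern (illustrative names and dummy arithmetic, not the library code):

import jax
import jax.numpy as jnp

def one_update(i, carry):
    # A real update would compute losses from params and a sampled minibatch;
    # here the metrics are just dummy arithmetic on the previous values.
    carry["info"] = {
        "actor_loss": carry["info"]["actor_loss"] + 1.0,
        "ent_coef_value": carry["info"]["ent_coef_value"] * 0.5,
    }
    return carry

# Every metric that should survive the loop must already exist in the initial
# carry, because fori_loop requires a fixed pytree structure for the carry.
init_carry = {
    "info": {
        "actor_loss": jnp.array(0.0),
        "ent_coef_value": jnp.array(1.0),
    }
}

out = jax.lax.fori_loop(0, 5, one_update, init_carry)
print(out["info"]["actor_loss"], out["info"]["ent_coef_value"])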
8 changes: 4 additions & 4 deletions sbx/crossq/policies.py
@@ -425,21 +425,21 @@ def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:

     @staticmethod
     @jax.jit
-    def sample_action(actor_state, obervations, key):
+    def sample_action(actor_state, observations, key):
         dist = actor_state.apply_fn(
             {"params": actor_state.params, "batch_stats": actor_state.batch_stats},
-            obervations,
+            observations,
             train=False,
         )
         action = dist.sample(seed=key)
         return action

     @staticmethod
     @jax.jit
-    def select_action(actor_state, obervations):
+    def select_action(actor_state, observations):
         return actor_state.apply_fn(
             {"params": actor_state.params, "batch_stats": actor_state.batch_stats},
-            obervations,
+            observations,
             train=False,
         ).mode()

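Context for the call signature above: the CrossQ actor contains BatchNorm layers, so `apply_fn` takes both the `params` and `batch_stats` collections and is called with `train=False` at action-selection time, i.e. running statistics are used and not updated. A small, self-contained Flax sketch of that calling convention (the module is a stand-in, not the CrossQ actor):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyNet(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
        x = nn.Dense(16)(x)
        # At inference, use_running_average=True keeps the stored statistics fixed.
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.Dense(2)(x)

variables = TinyNet().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
out = TinyNet().apply(
    {"params": variables["params"], "batch_stats": variables["batch_stats"]},
    jnp.ones((1, 4)),
    train=False,  # inference: do not update batch statistics
)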
40 changes: 21 additions & 19 deletions sbx/ppo/policies.py
@@ -20,23 +20,23 @@


 class Critic(nn.Module):
-    n_units: int = 256
+    net_arch: Sequence[int]
     activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
         x = Flatten()(x)
-        x = nn.Dense(self.n_units)(x)
-        x = self.activation_fn(x)
-        x = nn.Dense(self.n_units)(x)
-        x = self.activation_fn(x)
+        for n_units in self.net_arch:
+            x = nn.Dense(n_units)(x)
+            x = self.activation_fn(x)
+
         x = nn.Dense(1)(x)
         return x


 class Actor(nn.Module):
     action_dim: int
-    n_units: int = 256
+    net_arch: Sequence[int]
     log_std_init: float = 0.0
     activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
     # For Discrete, MultiDiscrete and MultiBinary actions
@@ -60,10 +60,11 @@ def __post_init__(self) -> None:
     @nn.compact
     def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
         x = Flatten()(x)
-        x = nn.Dense(self.n_units)(x)
-        x = self.activation_fn(x)
-        x = nn.Dense(self.n_units)(x)
-        x = self.activation_fn(x)
+
+        for n_units in self.net_arch:
+            x = nn.Dense(n_units)(x)
+            x = self.activation_fn(x)
+
         action_logits = nn.Dense(self.action_dim)(x)
         if self.num_discrete_choices is None:
             # Continuous actions
@@ -131,18 +132,19 @@ def __init__(
             features_extractor_kwargs,
             optimizer_class=optimizer_class,
             optimizer_kwargs=optimizer_kwargs,
-            squash_output=True,
+            squash_output=False,
         )
         self.log_std_init = log_std_init
         self.activation_fn = activation_fn
         if net_arch is not None:
             if isinstance(net_arch, list):
-                self.n_units = net_arch[0]
+                self.net_arch_pi = self.net_arch_vf = net_arch
             else:
                 assert isinstance(net_arch, dict)
-                self.n_units = net_arch["pi"][0]
+                self.net_arch_pi = net_arch["pi"]
+                self.net_arch_vf = net_arch["vf"]
         else:
-            self.n_units = 64
+            self.net_arch_pi = self.net_arch_vf = [64, 64]
         self.use_sde = use_sde

         self.key = self.noise_key = jax.random.PRNGKey(0)
@@ -188,7 +190,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
             raise NotImplementedError(f"{self.action_space}")

         self.actor = Actor(
-            n_units=self.n_units,
+            net_arch=self.net_arch_pi,
             log_std_init=self.log_std_init,
             activation_fn=self.activation_fn,
             **actor_kwargs,  # type: ignore[arg-type]
@@ -208,7 +210,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
             ),
         )

-        self.vf = Critic(n_units=self.n_units, activation_fn=self.activation_fn)
+        self.vf = Critic(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)

         self.vf_state = TrainState.create(
             apply_fn=self.vf.apply,
@@ -249,9 +251,9 @@ def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray:

     @staticmethod
     @jax.jit
-    def _predict_all(actor_state, vf_state, obervations, key):
-        dist = actor_state.apply_fn(actor_state.params, obervations)
+    def _predict_all(actor_state, vf_state, observations, key):
+        dist = actor_state.apply_fn(actor_state.params, observations)
         actions = dist.sample(seed=key)
         log_probs = dist.log_prob(actions)
-        values = vf_state.apply_fn(vf_state.params, obervations).flatten()
+        values = vf_state.apply_fn(vf_state.params, observations).flatten()
         return actions, log_probs, values
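With the loops above, the actor and critic widths come from `net_arch` instead of a single shared `n_units`. A hedged usage sketch, assuming sbx keeps SB3's `policy_kwargs` convention (the environment id and layer sizes are only examples):

from sbx import PPO

# Separate architectures for the policy (pi) and the value function (vf):
model = PPO(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(net_arch=dict(pi=[64, 64], vf=[256, 256])),
    verbose=1,
)
# A plain list is shared by both networks, e.g. policy_kwargs=dict(net_arch=[128, 128]);
# omitting net_arch falls back to the [64, 64] default set in __init__ above.
model.learn(total_timesteps=2_000)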
8 changes: 5 additions & 3 deletions sbx/ppo/ppo.py
@@ -293,9 +293,11 @@ def train(self) -> None:
         # self.logger.record("train/clip_fraction", np.mean(clip_fractions))
         self.logger.record("train/pg_loss", pg_loss.item())
         self.logger.record("train/explained_variance", explained_var)
-        # if hasattr(self.policy, "log_std"):
-        #     self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
-
+        try:
+            log_std = self.policy.actor_state.params["params"]["log_std"]
+            self.logger.record("train/std", np.exp(log_std).mean().item())
+        except KeyError:
+            pass
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/clip_range", clip_range)
         # if self.clip_range_vf is not None:
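The replaced PyTorch-style logging becomes a lookup into the actor's Flax parameter tree; discrete-action actors have no `log_std` leaf, so the `KeyError` is swallowed and nothing is logged. A tiny sketch of that lookup on a hypothetical parameter tree:

import numpy as np

# Hypothetical tree: a continuous-action actor stores a top-level log_std parameter.
actor_params = {"params": {"log_std": np.array([-0.5, 0.0, 0.5])}}

try:
    log_std = actor_params["params"]["log_std"]
    print("train/std:", np.exp(log_std).mean().item())
except KeyError:
    # e.g. Discrete action spaces: no log_std parameter exists
    pass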
18 changes: 15 additions & 3 deletions sbx/sac/sac.py
@@ -220,7 +220,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.policy.actor_state,
             self.ent_coef_state,
             self.key,
-            (actor_loss_value, qf_loss_value, ent_coef_value),
+            (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value),
         ) = self._train(
             self.gamma,
             self.tau,
@@ -238,6 +238,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/actor_loss", actor_loss_value.item())
         self.logger.record("train/critic_loss", qf_loss_value.item())
+        self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item())
         self.logger.record("train/ent_coef", ent_coef_value.item())

     @staticmethod
@@ -391,6 +392,7 @@ def _train(
                 "actor_loss": jnp.array(0.0),
                 "qf_loss": jnp.array(0.0),
                 "ent_coef_loss": jnp.array(0.0),
+                "ent_coef_value": jnp.array(0.0),
             },
         }

@@ -438,7 +440,12 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
                 target_entropy,
                 key,
             )
-            info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value}
+            info = {
+                "actor_loss": actor_loss_value,
+                "qf_loss": qf_loss_value,
+                "ent_coef_loss": ent_coef_loss_value,
+                "ent_coef_value": ent_coef_value,
+            }

             return {
                 "actor_state": actor_state,
@@ -455,5 +462,10 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
             update_carry["actor_state"],
             update_carry["ent_coef_state"],
             update_carry["key"],
-            (update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
+            (
+                update_carry["info"]["actor_loss"],
+                update_carry["info"]["qf_loss"],
+                update_carry["info"]["ent_coef_loss"],
+                update_carry["info"]["ent_coef_value"],
+            ),
         )
4 changes: 2 additions & 2 deletions sbx/td3/policies.py
@@ -136,8 +136,8 @@ def forward(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:

     @staticmethod
     @jax.jit
-    def select_action(actor_state, obervations) -> np.ndarray:
-        return actor_state.apply_fn(actor_state.params, obervations)
+    def select_action(actor_state, observations) -> np.ndarray:
+        return actor_state.apply_fn(actor_state.params, observations)

     def _predict(self, observation: np.ndarray, deterministic: bool = True) -> np.ndarray:  # type: ignore[override]
         # TD3 is always deterministic
6 changes: 5 additions & 1 deletion sbx/tqc/tqc.py
@@ -224,7 +224,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.policy.actor_state,
             self.ent_coef_state,
             self.key,
-            (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_value),
+            (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value),
         ) = self._train(
             self.gamma,
             self.tau,
@@ -244,6 +244,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/actor_loss", actor_loss_value.item())
         self.logger.record("train/critic_loss", qf1_loss_value.item())
+        self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item())
         self.logger.record("train/ent_coef", ent_coef_value.item())

     @staticmethod
@@ -455,6 +456,7 @@ def _train(
                 "qf1_loss": jnp.array(0.0),
                 "qf2_loss": jnp.array(0.0),
                 "ent_coef_loss": jnp.array(0.0),
+                "ent_coef_value": jnp.array(0.0),
             },
         }

@@ -518,6 +520,7 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
                 "qf1_loss": qf1_loss_value,
                 "qf2_loss": qf2_loss_value,
                 "ent_coef_loss": ent_coef_loss_value,
+                "ent_coef_value": ent_coef_value,
             }

             return {
@@ -542,5 +545,6 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
                 update_carry["info"]["qf2_loss"],
                 update_carry["info"]["actor_loss"],
                 update_carry["info"]["ent_coef_loss"],
+                update_carry["info"]["ent_coef_value"],
             ),
         )
2 changes: 1 addition & 1 deletion sbx/version.txt
@@ -1 +1 @@
-0.19.0
+0.20.0
4 changes: 2 additions & 2 deletions setup.py
@@ -41,8 +41,8 @@
     packages=[package for package in find_packages() if package.startswith("sbx")],
     package_data={"sbx": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=2.4.0,<3.0",
-        "jax>=0.4.12",
+        "stable_baselines3>=2.5.0,<3.0",
+        "jax>=0.4.24",
         "jaxlib",
         "flax",
         "optax",
