diff --git a/make_dmc.py b/make_dmc.py
new file mode 100644
index 0000000..57374dd
--- /dev/null
+++ b/make_dmc.py
@@ -0,0 +1,157 @@
+#adapted from https://github.com/imgeorgiev/dmc2gymnasium
+
+import logging
+import numpy as np
+
+from dm_control import suite
+from dm_env import specs
+from gymnasium.core import Env
+from gymnasium.spaces import Box
+from gymnasium import spaces
+from gymnasium.wrappers import FlattenObservation, RescaleAction
+
+def _spec_to_box(spec, dtype=np.float32):
+    def extract_min_max(s):
+        assert s.dtype == np.float64 or s.dtype == np.float32
+        dim = int(np.prod(s.shape))
+        if type(s) == specs.Array:
+            bound = np.inf * np.ones(dim, dtype=np.float32)
+            return -bound, bound
+        elif type(s) == specs.BoundedArray:
+            zeros = np.zeros(dim, dtype=np.float32)
+            return s.minimum + zeros, s.maximum + zeros
+        else:
+            logging.error("Unrecognized type")
+    mins, maxs = [], []
+    for s in spec:
+        mn, mx = extract_min_max(s)
+        mins.append(mn)
+        maxs.append(mx)
+    low = np.concatenate(mins, axis=0).astype(dtype)
+    high = np.concatenate(maxs, axis=0).astype(dtype)
+    assert low.shape == high.shape
+    return Box(low, high, dtype=dtype)
+
+
+def _flatten_obs(obs, dtype=np.float32):
+    obs_pieces = []
+    for v in obs.values():
+        flat = np.array([v]) if np.isscalar(v) else v.ravel()
+        obs_pieces.append(flat)
+    return np.concatenate(obs_pieces, axis=0).astype(dtype)
+
+
+class DMCGym(Env):
+    def __init__(
+        self,
+        env_name,
+        task_kwargs={},
+        environment_kwargs={},
+        #rendering="egl",
+        render_height=64,
+        render_width=64,
+        render_camera_id=0,
+        action_repeat=1
+    ):
+        domain = env_name.split('-')[0]
+        task = env_name.split('-')[1]
+        self._env = suite.load(
+            domain,
+            task,
+            task_kwargs,
+            environment_kwargs,
+        )
+
+        # placeholder to allow built in gymnasium rendering
+        self.render_mode = "rgb_array"
+        self.render_height = render_height
+        self.render_width = render_width
+        self.render_camera_id = render_camera_id
+
+        self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32)
+        self._norm_action_space = spaces.Box(
+            low=-1.0,
+            high=1.0,
+            shape=self._true_action_space.shape,
+            dtype=np.float32
+        )
+
+        self._observation_space = _spec_to_box(self._env.observation_spec().values())
+        self._action_space = _spec_to_box([self._env.action_spec()])
+        self.action_repeat = action_repeat
+
+        # set seed if provided with task_kwargs
+        if "random" in task_kwargs:
+            seed = task_kwargs["random"]
+            self._observation_space.seed(seed)
+            self._action_space.seed(seed)
+
+    def __getattr__(self, name):
+        """Add this here so that we can easily access attributes of the underlying env"""
+        return getattr(self._env, name)
+
+    @property
+    def observation_space(self):
+        return self._observation_space
+
+    @property
+    def action_space(self):
+        return self._action_space
+
+    @property
+    def reward_range(self):
+        """DMC always has a per-step reward range of (0, 1)"""
+        return 0, 1
+
+    def _convert_action(self, action):
+        action = action.astype(np.float64)
+        true_delta = self._true_action_space.high - self._true_action_space.low
+        norm_delta = self._norm_action_space.high - self._norm_action_space.low
+        action = (action - self._norm_action_space.low) / norm_delta
+        action = action * true_delta + self._true_action_space.low
+        action = action.astype(np.float32)
+        return action
+
+    def step(self, action):
+        assert self._norm_action_space.contains(action)
+        action = self._convert_action(action)
+        assert self._true_action_space.contains(action)
+        action = np.clip(action, self._true_action_space.low, self._true_action_space.high)  # guard against numerical drift from the rescaling
+        reward = 0
+        info = {}
+        for i in range(self.action_repeat):
+            timestep = self._env.step(action)
+            observation = _flatten_obs(timestep.observation)
+            reward += timestep.reward
+            termination = False  # we never reach a goal
+            truncation = timestep.last()
+            if truncation:
+                return observation, reward, termination, truncation, info
+        return observation, reward, termination, truncation, info
+
+    def reset(self, seed=None, options=None):
+        if seed is not None:
+            if not isinstance(seed, np.random.RandomState):
+                seed = np.random.RandomState(seed)
+            self._env.task._random = seed
+        if options:
+            logging.warning("Currently doing nothing with options={:}".format(options))
+        timestep = self._env.reset()
+        observation = _flatten_obs(timestep.observation)
+        info = {}
+        return observation, info
+
+    def render(self, height=None, width=None, camera_id=None):
+        height = height or self.render_height
+        width = width or self.render_width
+        camera_id = camera_id or self.render_camera_id
+        return self._env.physics.render(height=height, width=width, camera_id=camera_id)
+
+
+def make_env_dmc(env_name: str, action_repeat: int = 1) -> Env:
+    env = DMCGym(env_name=env_name, action_repeat=action_repeat)
+    env = RescaleAction(env, -1.0, 1.0)
+    env = FlattenObservation(env)
+    return env
+
+
\ No newline at end of file
diff --git a/sbx/__init__.py b/sbx/__init__.py
index a7c13bc..8ec1325 100644
--- a/sbx/__init__.py
+++ b/sbx/__init__.py
@@ -7,6 +7,8 @@
 from sbx.sac import SAC
 from sbx.td3 import TD3
 from sbx.tqc import TQC
+from sbx.bro import BRO
+
 
 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
@@ -30,4 +32,5 @@ def DroQ(*args, **kwargs):
     "SAC",
     "TD3",
     "TQC",
+    "BRO",
 ]
diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py
index 37b4e3b..70b05c6 100644
--- a/sbx/bro/bro.py
+++ b/sbx/bro/bro.py
@@ -19,6 +19,7 @@
 from sbx.bro.policies import BROPolicy
 
 
+
 class EntropyCoef(nn.Module):
     ent_coef_init: float = 1.0
 
@@ -26,8 +27,7 @@ class EntropyCoef(nn.Module):
     def __call__(self) -> jnp.ndarray:
         log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)))
         return jnp.exp(log_ent_coef)
-
-
+
 class ConstantEntropyCoef(nn.Module):
     ent_coef_init: float = 1.0
 
@@ -38,7 +38,27 @@ def __call__(self) -> float:
         self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init))
         return self.ent_coef_init
 
-
+@jax.jit
+def _get_stats(
+    actor_state: RLTrainState,
+    qf_state: RLTrainState,
+    ent_coef_state: TrainState,
+    observations: jax.Array,
+    key: jax.Array,
+):
+    key, dropout_key, noise_key = jax.random.split(key, 3)
+    dist = actor_state.apply_fn(actor_state.params, observations)
+    actor_actions = dist.sample(seed=noise_key)
+    log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
+    qf_pi = qf_state.apply_fn(
+        qf_state.params,
+        observations,
+        actor_actions,
+        rngs={"dropout": dropout_key},
+    )
+    ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+    return qf_pi.mean(), jnp.absolute(actor_actions).mean(), ent_coef_value.mean(), -log_prob.mean()
+
 class BRO(OffPolicyAlgorithmJax):
     policy_aliases: ClassVar[Dict[str, Type[BROPolicy]]] = {  # type: ignore[assignment]
         "MlpPolicy": BROPolicy,
@@ -109,11 +129,13 @@ def __init__(
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
         self.target_entropy = target_entropy
+        self.init_key = jax.random.PRNGKey(seed)
 
         self.n_quantiles = n_quantiles
         taus_ = jnp.arange(0, n_quantiles+1) / n_quantiles
         self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None]
+        self.distributional = True if
self.n_quantiles > 1 else False if _init_setup_model: self._setup_model() @@ -121,6 +143,61 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() + if not hasattr(self, "policy") or self.policy is None: + self.policy = self.policy_class( # type: ignore[assignment] + self.observation_space, + self.action_space, + self.lr_schedule, + self.n_quantiles, + **self.policy_kwargs, + ) + + assert isinstance(self.qf_learning_rate, float) + self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) + + self.key, ent_key = jax.random.split(self.key, 2) + + self.actor = self.policy.actor # type: ignore[assignment] + self.qf = self.policy.qf # type: ignore[assignment] + + # The entropy coefficient or entropy can be learned automatically + # see Automating Entropy Adjustment for Maximum Entropy RL section + # of https://arxiv.org/abs/1812.05905 + if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"): + # Default initial value of ent_coef when learned + ent_coef_init = 1.0 + if "_" in self.ent_coef_init: + ent_coef_init = float(self.ent_coef_init.split("_")[1]) + assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0" + + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 + self.ent_coef = EntropyCoef(ent_coef_init) + else: + # This will throw an error if a malformed string (different from 'auto') is passed + assert isinstance( + self.ent_coef_init, float + ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}" + self.ent_coef = ConstantEntropyCoef(self.ent_coef_init) # type: ignore[assignment] + + self.ent_coef_state = TrainState.create( + apply_fn=self.ent_coef.apply, + params=self.ent_coef.init(ent_key)["params"], + tx=optax.adam( + learning_rate=self.learning_rate, b1=0.5 + ), + ) + + # Target entropy is used when learning the entropy coefficient + if self.target_entropy == "auto": + # automatically set target entropy if needed + self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) / 2 # type: ignore + else: + # Force conversion + # this will also throw an error for unexpected string + self.target_entropy = float(self.target_entropy) + + def reset(self): if not hasattr(self, "policy") or self.policy is None: self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, @@ -132,7 +209,7 @@ def _setup_model(self) -> None: assert isinstance(self.qf_learning_rate, float) - self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) + self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) self.key, ent_key = jax.random.split(self.key, 2) @@ -163,7 +240,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.learning_rate, b1=0.5 ), ) @@ -238,10 +315,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, ) self._n_updates += gradient_steps - self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - self.logger.record("train/actor_loss", actor_loss_value.item()) - self.logger.record("train/critic_loss", qf_loss_value.item()) - self.logger.record("train/ent_coef", ent_coef_value.item()) + return { + 'actor_loss': actor_loss_value.item(), + 'critic_loss': qf_loss_value.item(), + 'ent_loss': 
ent_coef_value.item(), + } + #self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + #self.logger.record("train/actor_loss", actor_loss_value.item()) + #self.logger.record("train/critic_loss", qf_loss_value.item()) + #self.logger.record("train/ent_coef", ent_coef_value.item()) @staticmethod @jax.jit @@ -425,7 +507,40 @@ def update_actor_and_temperature( ) ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key + + def get_stats(self, batch_size): + data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + + if isinstance(data.observations, dict): + keys = list(self.observation_space.keys()) # type: ignore[attr-defined] + obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) + next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1) + else: + obs = data.observations.numpy() + next_obs = data.next_observations.numpy() + # Convert to numpy + data = ReplayBufferSamplesNp( # type: ignore[assignment] + obs, + data.actions.numpy(), + next_obs, + data.dones.numpy().flatten(), + data.rewards.numpy().flatten(), + ) + q, a, temp, ent = _get_stats( + self.policy.actor_state, + self.policy.qf_state, + self.ent_coef_state, + obs, + self.key, + ) + + return { + 'q': q.mean().item(), + 'a': a.item(), + 'temp': temp.item(), + 'entropy': ent.item()} + @classmethod @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"]) def _train( diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 5297d6e..dc2a56e 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -21,7 +21,7 @@ class BroNetBlock(nn.Module): activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: out = nn.Dense(self.n_units)(x) out = nn.LayerNorm()(out) out = self.activation_fn(out) @@ -74,10 +74,10 @@ class Critic(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: x = Flatten()(x) - x = jnp.concatenate([x, action], -1) - x = BroNet(self.net_arch, self.activation_fn)(x) - x = nn.Dense(self.n_quantiles)(x) - return x + out = jnp.concatenate([x, action], -1) + out = BroNet(self.net_arch, self.activation_fn)(out) + out = nn.Dense(self.n_quantiles)(out) + return out class VectorCritic(nn.Module): net_arch: Sequence[int] @@ -117,7 +117,7 @@ def __init__( n_quantiles: int = 100, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, dropout_rate: float = 0.0, - layer_norm: bool = False, + layer_norm: bool = True, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, # Note: most gSDE parameters are not used @@ -128,7 +128,7 @@ def __init__( features_extractor_class=None, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, @@ -139,7 +139,7 @@ def __init__( # but shows only little overall improvement. 
optimizer_kwargs = {} if optimizer_class in [optax.adam, optax.adamw]: - optimizer_kwargs["b1"] = 0.5 + pass super().__init__( observation_space, @@ -163,7 +163,7 @@ def __init__( else: self.net_arch_pi = [256] # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR - self.net_arch_qf = [1024, 1024] + self.net_arch_qf = [512, 512] print(self.net_arch_qf) self.n_critics = n_critics self.use_sde = use_sde diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index ba0b9ed..11e6c8c 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -71,7 +71,7 @@ def __init__( support_multi_env=support_multi_env, ) # Will be updated later - self.key = jax.random.PRNGKey(0) + self.key = jax.random.PRNGKey(seed) # Note: we do not allow schedule for it self.qf_learning_rate = qf_learning_rate @@ -115,6 +115,7 @@ def _setup_model(self) -> None: device="cpu", # force cpu device to easy torch -> numpy conversion n_envs=self.n_envs, optimize_memory_usage=self.optimize_memory_usage, + handle_timeout_termination=False, **replay_buffer_kwargs, ) # Convert train freq parameter to TrainFreq object diff --git a/scripts/dmc/acrobot.sh b/scripts/dmc/acrobot.sh new file mode 100644 index 0000000..f5e59bf --- /dev/null +++ b/scripts/dmc/acrobot.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=acrobot-swingup + +wait diff --git a/scripts/dmc/cheetah.sh b/scripts/dmc/cheetah.sh new file mode 100644 index 0000000..9c025df --- /dev/null +++ b/scripts/dmc/cheetah.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=cheetah-run + +wait diff --git a/scripts/dmc/dog_run.sh b/scripts/dmc/dog_run.sh new file mode 100644 index 0000000..45b27ea --- /dev/null +++ b/scripts/dmc/dog_run.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-run + +wait diff --git a/scripts/dmc/dog_stand.sh b/scripts/dmc/dog_stand.sh new file mode 100644 index 0000000..07a4584 --- /dev/null +++ b/scripts/dmc/dog_stand.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-stand + +wait diff --git a/scripts/dmc/dog_trot.sh b/scripts/dmc/dog_trot.sh new file mode 100644 index 0000000..d7b1411 --- /dev/null +++ b/scripts/dmc/dog_trot.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-trot + +wait diff --git a/scripts/dmc/dog_walk.sh b/scripts/dmc/dog_walk.sh new file mode 100644 index 0000000..4221da2 --- /dev/null +++ b/scripts/dmc/dog_walk.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-walk + +wait diff --git a/scripts/dmc/finger.sh b/scripts/dmc/finger.sh new file mode 100644 index 
0000000..8a10360 --- /dev/null +++ b/scripts/dmc/finger.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=finger-turn_hard + +wait diff --git a/scripts/dmc/fish.sh b/scripts/dmc/fish.sh new file mode 100644 index 0000000..7dca0fa --- /dev/null +++ b/scripts/dmc/fish.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=fish-swim + +wait diff --git a/scripts/dmc/hopper.sh b/scripts/dmc/hopper.sh new file mode 100644 index 0000000..fea40d9 --- /dev/null +++ b/scripts/dmc/hopper.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=hopper-hop + +wait diff --git a/scripts/dmc/humanoid_run.sh b/scripts/dmc/humanoid_run.sh new file mode 100644 index 0000000..f434d0b --- /dev/null +++ b/scripts/dmc/humanoid_run.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-run + +wait diff --git a/scripts/dmc/humanoid_stand.sh b/scripts/dmc/humanoid_stand.sh new file mode 100644 index 0000000..e22bff0 --- /dev/null +++ b/scripts/dmc/humanoid_stand.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-stand + +wait diff --git a/scripts/dmc/humanoid_walk.sh b/scripts/dmc/humanoid_walk.sh new file mode 100644 index 0000000..64a2a37 --- /dev/null +++ b/scripts/dmc/humanoid_walk.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-walk + +wait diff --git a/scripts/dmc/pendulum.sh b/scripts/dmc/pendulum.sh new file mode 100644 index 0000000..0e54c35 --- /dev/null +++ b/scripts/dmc/pendulum.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=pendulum-swingup + +wait diff --git a/scripts/dmc/quadruped.sh b/scripts/dmc/quadruped.sh new file mode 100644 index 0000000..e91ba25 --- /dev/null +++ b/scripts/dmc/quadruped.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=quadruped-run + +wait diff --git a/scripts/dmc/walker.sh b/scripts/dmc/walker.sh new file mode 100644 index 0000000..3da6aa7 --- /dev/null +++ b/scripts/dmc/walker.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=walker-run + +wait diff --git a/scripts/gym/ant.sh b/scripts/gym/ant.sh new file mode 100644 index 0000000..1d63aee --- /dev/null +++ b/scripts/gym/ant.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc 
+conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Ant-v4 + +wait diff --git a/scripts/gym/cheetah.sh b/scripts/gym/cheetah.sh new file mode 100644 index 0000000..63108f6 --- /dev/null +++ b/scripts/gym/cheetah.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=HalfCheetah-v4 + +wait diff --git a/scripts/gym/hopper.sh b/scripts/gym/hopper.sh new file mode 100644 index 0000000..3eeca18 --- /dev/null +++ b/scripts/gym/hopper.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Hopper-v4 + +wait diff --git a/scripts/gym/humanoid.sh b/scripts/gym/humanoid.sh new file mode 100644 index 0000000..7d2a7d2 --- /dev/null +++ b/scripts/gym/humanoid.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Humanoid-v4 + +wait diff --git a/scripts/gym/walker.sh b/scripts/gym/walker.sh new file mode 100644 index 0000000..8a3e68c --- /dev/null +++ b/scripts/gym/walker.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Walker2d-v4 + +wait diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh deleted file mode 100755 index 9ecb207..0000000 --- a/scripts/run_tests.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. 
-v --color=yes -m "not expensive"
diff --git a/train_torch.py b/train_torch.py
new file mode 100644
index 0000000..f747f7d
--- /dev/null
+++ b/train_torch.py
@@ -0,0 +1,116 @@
+import gymnasium as gym
+from sbx.bro.bro import BRO
+import numpy as np
+from make_dmc import make_env_dmc
+import wandb
+
+from absl import app, flags
+
+flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.')
+flags.DEFINE_string('benchmark', 'dmc', 'Benchmark name (dmc or gym).')
+flags.DEFINE_integer('learning_starts', 2500, 'Number of environment steps before training starts.')
+flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.')
+flags.DEFINE_integer('batch_size', 128, 'Mini batch size.')
+flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.')
+flags.DEFINE_integer('n_quantiles', 1, 'Number of critic quantiles (1 disables the distributional critic).')
+flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.')
+flags.DEFINE_integer('num_episodes', 5, 'Number of episodes used for evaluation.')
+FLAGS = flags.FLAGS
+
+'''
+class flags:
+    benchmark: str = 'dmc'
+    env_name: str = "cheetah-run"
+    learning_starts: int = 9999
+    training_steps: int = 10_000
+    seed: int = 0
+    batch_size: int = 128
+    gradient_steps: int = 2
+    use_wandb: bool = False
+    n_quantiles: int = 100
+    eval_freq: int = 5000
+    num_episodes: int = 5
+FLAGS = flags()
+'''
+
+def evaluate(env, model, num_episodes):
+    returns = np.zeros(num_episodes)
+    for episode in range(num_episodes):
+        not_done = True
+        obs, _ = env.reset(seed=np.random.randint(1e7))
+        obs = np.expand_dims(obs, axis=0)
+        ret = 0
+        while not_done:
+            action = model.policy.forward(obs, deterministic=True)[0]
+            next_obs, reward, term, trun, info = env.step(action)
+            next_obs = np.expand_dims(next_obs, axis=0)
+            obs = next_obs
+            ret += reward
+            if term or trun:
+                not_done = False
+        returns[episode] = ret
+    return {'return_eval': returns.mean()}
+
+def log_to_wandb(step, infos):
+    dict_to_log = {'timestep': step}
+    for info_key in infos:
+        dict_to_log[f'{info_key}'] = infos[info_key]
+    wandb.log(dict_to_log, step=step)
+
+def get_env(benchmark, env_name):
+    if benchmark == 'gym':
+        return gym.make(env_name)
+    else:
+        return make_env_dmc(env_name=env_name, action_repeat=1)
+
+def main(_):
+    SEED = np.random.randint(1e7)
+    wandb.init(
+        config=FLAGS,
+        entity='naumix',
+        project='BRO_SBX',
+        group=f'{FLAGS.env_name}',
+        name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}_repro'
+    )
+
+    env = get_env(FLAGS.benchmark, FLAGS.env_name)
+    eval_env = get_env(FLAGS.benchmark, FLAGS.env_name)
+    model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, gradient_steps=FLAGS.gradient_steps)
+    np.random.seed(SEED)
+
+    reset_list = [15000]  # steps at which the agent is fully re-initialized via model.reset()
+    obs, _ = env.reset(seed=np.random.randint(1e7))
+    obs = np.expand_dims(obs, axis=0)
+
+    for i in range(1, FLAGS.training_steps+1):
+        if i <= FLAGS.learning_starts:
+            action = env.action_space.sample()
+        else:
+            action = model.policy.forward(obs, deterministic=False)[0]
+        next_obs, reward, term, trun, info = env.step(action)
+        next_obs = np.expand_dims(next_obs, axis=0)
+
+        done = 1.0 if (term and not trun) else 0.0  # time-limit truncations are not treated as true terminations
+
+        model.replay_buffer.add(obs, next_obs, action, reward, done, info)
+        if term or trun:
+            obs, _ = env.reset(seed=np.random.randint(1e7))
+            obs = np.expand_dims(obs, axis=0)
+        else:
+            obs = next_obs
+
+        if i in reset_list:
+            model.reset()
+
+        if i >= FLAGS.learning_starts:
+            train_info = model.train(FLAGS.gradient_steps, FLAGS.batch_size)
+
+        if i % FLAGS.eval_freq == 0:
+            eval_info = evaluate(eval_env, model, FLAGS.num_episodes)
+            stat_info = model.get_stats(FLAGS.batch_size)
+            info = {**eval_info, **train_info, **stat_info}
+            #print(eval_info)
+            log_to_wandb(i, info)
+
+if __name__ == '__main__':
+    app.run(main)
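
For readers who want to try the new pieces without the wandb loop in train_torch.py, the minimal sketch below shows how the additions in this patch fit together. It is not part of the patch: it assumes that BRO inherits the usual SB3-style learn() entry point from the sbx off-policy base class, and the hyperparameters simply mirror the values used in train_torch.py and the run scripts; the greedy rollout copies the evaluate() pattern from train_torch.py.

# usage_sketch.py -- illustrative only, assumes the SB3-style learn() API
import numpy as np

from sbx import BRO                 # export added in sbx/__init__.py above
from make_dmc import make_env_dmc   # DMC -> Gymnasium helper added above

env = make_env_dmc("cheetah-run", action_repeat=1)  # flat observations, actions rescaled to [-1, 1]
model = BRO(
    "MlpPolicy",
    env,
    n_quantiles=100,        # > 1 switches the critic to its quantile (distributional) head
    gradient_steps=2,       # updates per environment step, as in train_torch.py
    learning_starts=2500,
    batch_size=128,
    seed=0,
)
model.learn(total_timesteps=50_000)

# Greedy rollout, mirroring evaluate() in train_torch.py
obs, _ = env.reset(seed=0)
obs = np.expand_dims(obs, axis=0)
episode_return, done = 0.0, False
while not done:
    action = model.policy.forward(obs, deterministic=True)[0]
    next_obs, reward, term, trunc, _ = env.step(action)
    obs = np.expand_dims(next_obs, axis=0)
    episode_return += reward
    done = term or trunc
print(f"episode return: {episode_return:.1f}")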