From 2c0322be63805f523ada5582183aeaceb197d6cc Mon Sep 17 00:00:00 2001
From: naumix
Date: Fri, 1 Nov 2024 19:00:19 +0100
Subject: [PATCH 1/6] add

---
 sbx/__init__.py                    | 3 +++
 sbx/bro/bro.py                     | 9 +++++----
 sbx/common/off_policy_algorithm.py | 1 +
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/sbx/__init__.py b/sbx/__init__.py
index a7c13bc..8ec1325 100644
--- a/sbx/__init__.py
+++ b/sbx/__init__.py
@@ -7,6 +7,8 @@ from sbx.sac import SAC
 from sbx.td3 import TD3
 from sbx.tqc import TQC
 
+from sbx.bro import BRO
+
 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), "version.txt")
 
@@ -30,4 +32,5 @@ def DroQ(*args, **kwargs):
     "SAC",
     "TD3",
     "TQC",
+    "BRO",
 ]
diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py
index 37b4e3b..fe1c1e9 100644
--- a/sbx/bro/bro.py
+++ b/sbx/bro/bro.py
@@ -19,6 +19,7 @@
 
 from sbx.bro.policies import BROPolicy
 
+
 class EntropyCoef(nn.Module):
     ent_coef_init: float = 1.0
 
@@ -238,10 +239,10 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.key,
         )
         self._n_updates += gradient_steps
-        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
-        self.logger.record("train/actor_loss", actor_loss_value.item())
-        self.logger.record("train/critic_loss", qf_loss_value.item())
-        self.logger.record("train/ent_coef", ent_coef_value.item())
+        #self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+        #self.logger.record("train/actor_loss", actor_loss_value.item())
+        #self.logger.record("train/critic_loss", qf_loss_value.item())
+        #self.logger.record("train/ent_coef", ent_coef_value.item())
 
     @staticmethod
     @jax.jit
diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py
index ba0b9ed..fddf57b 100644
--- a/sbx/common/off_policy_algorithm.py
+++ b/sbx/common/off_policy_algorithm.py
@@ -115,6 +115,7 @@ def _setup_model(self) -> None:
                 device="cpu",  # force cpu device to easy torch -> numpy conversion
                 n_envs=self.n_envs,
                 optimize_memory_usage=self.optimize_memory_usage,
+                handle_timeout_termination=False,
                 **replay_buffer_kwargs,
             )
             # Convert train freq parameter to TrainFreq object

From f23b80f17172e0250a9dfef0976dd4a5422c1d48 Mon Sep 17 00:00:00 2001
From: naumix
Date: Fri, 1 Nov 2024 22:44:59 +0100
Subject: [PATCH 2/6] add scripts

---
 make_dmc.py                   | 157 ++++++++++++++++++++++++++++++++++
 sbx/bro/bro.py                |   5 ++
 scripts/dmc/acrobot.sh        |  13 +++
 scripts/dmc/cheetah.sh        |  13 +++
 scripts/dmc/dog_run.sh        |  13 +++
 scripts/dmc/dog_stand.sh      |  13 +++
 scripts/dmc/dog_trot.sh       |  13 +++
 scripts/dmc/dog_walk.sh       |  13 +++
 scripts/dmc/finger.sh         |  13 +++
 scripts/dmc/fish.sh           |  13 +++
 scripts/dmc/hopper.sh         |  13 +++
 scripts/dmc/humanoid_run.sh   |  13 +++
 scripts/dmc/humanoid_stand.sh |  13 +++
 scripts/dmc/humanoid_walk.sh  |  13 +++
 scripts/dmc/pendulum.sh       |  13 +++
 scripts/dmc/quadruped.sh      |  13 +++
 scripts/dmc/walker.sh         |  13 +++
 scripts/run_tests.sh          |   2 -
 train.py                      | 110 ++++++++++++++++++++++++
 19 files changed, 467 insertions(+), 2 deletions(-)
 create mode 100644 make_dmc.py
 create mode 100644 scripts/dmc/acrobot.sh
 create mode 100644 scripts/dmc/cheetah.sh
 create mode 100644 scripts/dmc/dog_run.sh
 create mode 100644 scripts/dmc/dog_stand.sh
 create mode 100644 scripts/dmc/dog_trot.sh
 create mode 100644 scripts/dmc/dog_walk.sh
 create mode 100644 scripts/dmc/finger.sh
 create mode 100644 scripts/dmc/fish.sh
 create mode 100644 scripts/dmc/hopper.sh
 create mode 100644 scripts/dmc/humanoid_run.sh
 create mode 100644 scripts/dmc/humanoid_stand.sh
 create mode 100644 scripts/dmc/humanoid_walk.sh
 create mode 
100644 scripts/dmc/pendulum.sh create mode 100644 scripts/dmc/quadruped.sh create mode 100644 scripts/dmc/walker.sh delete mode 100755 scripts/run_tests.sh create mode 100644 train.py diff --git a/make_dmc.py b/make_dmc.py new file mode 100644 index 0000000..57374dd --- /dev/null +++ b/make_dmc.py @@ -0,0 +1,157 @@ +#adapted from https://github.com/imgeorgiev/dmc2gymnasium + +import logging +import numpy as np + +from dm_control import suite +from dm_env import specs +from gymnasium.core import Env +from gymnasium.spaces import Box +from gymnasium import spaces +from gymnasium.wrappers import FlattenObservation, RescaleAction + +def _spec_to_box(spec, dtype=np.float32): + def extract_min_max(s): + assert s.dtype == np.float64 or s.dtype == np.float32 + dim = int(np.prod(s.shape)) + if type(s) == specs.Array: + bound = np.inf * np.ones(dim, dtype=np.float32) + return -bound, bound + elif type(s) == specs.BoundedArray: + zeros = np.zeros(dim, dtype=np.float32) + return s.minimum + zeros, s.maximum + zeros + else: + logging.error("Unrecognized type") + mins, maxs = [], [] + for s in spec: + mn, mx = extract_min_max(s) + mins.append(mn) + maxs.append(mx) + low = np.concatenate(mins, axis=0).astype(dtype) + high = np.concatenate(maxs, axis=0).astype(dtype) + assert low.shape == high.shape + return Box(low, high, dtype=dtype) + + +def _flatten_obs(obs, dtype=np.float32): + obs_pieces = [] + for v in obs.values(): + flat = np.array([v]) if np.isscalar(v) else v.ravel() + obs_pieces.append(flat) + return np.concatenate(obs_pieces, axis=0).astype(dtype) + + +class DMCGym(Env): + def __init__( + self, + env_name, + task_kwargs={}, + environment_kwargs={}, + #rendering="egl", + render_height=64, + render_width=64, + render_camera_id=0, + action_repeat=1 + ): + domain = env_name.split('-')[0] + task = env_name.split('-')[1] + self._env = suite.load( + domain, + task, + task_kwargs, + environment_kwargs, + ) + + # placeholder to allow built in gymnasium rendering + self.render_mode = "rgb_array" + self.render_height = render_height + self.render_width = render_width + self.render_camera_id = render_camera_id + + self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) + self._norm_action_space = spaces.Box( + low=-1.0, + high=1.0, + shape=self._true_action_space.shape, + dtype=np.float32 + ) + + self._observation_space = _spec_to_box(self._env.observation_spec().values()) + self._action_space = _spec_to_box([self._env.action_spec()]) + self.action_repeat = action_repeat + + # set seed if provided with task_kwargs + if "random" in task_kwargs: + seed = task_kwargs["random"] + self._observation_space.seed(seed) + self._action_space.seed(seed) + + def __getattr__(self, name): + """Add this here so that we can easily access attributes of the underlying env""" + return getattr(self._env, name) + + @property + def observation_space(self): + return self._observation_space + + @property + def action_space(self): + return self._action_space + + @property + def reward_range(self): + """DMC always has a per-step reward range of (0, 1)""" + return 0, 1 + + def _convert_action(self, action): + action = action.astype(np.float64) + true_delta = self._true_action_space.high - self._true_action_space.low + norm_delta = self._norm_action_space.high - self._norm_action_space.low + action = (action - self._norm_action_space.low) / norm_delta + action = action * true_delta + self._true_action_space.low + action = action.astype(np.float32) + return action + + def step(self, action): + assert 
self._norm_action_space.contains(action) + action = self._convert_action(action) + assert self._true_action_space.contains(action) + action = np.clip(action, -1.0, 1.0) + reward = 0 + info = {} + for i in range(self.action_repeat): + timestep = self._env.step(action) + observation = _flatten_obs(timestep.observation) + reward += timestep.reward + termination = False # we never reach a goal + truncation = timestep.last() + if truncation: + return observation, reward, termination, truncation, info + return observation, reward, termination, truncation, info + + def reset(self, seed=None, options=None): + if seed is not None: + if not isinstance(seed, np.random.RandomState): + seed = np.random.RandomState(seed) + self._env.task._random = seed + if options: + logging.warn("Currently doing nothing with options={:}".format(options)) + timestep = self._env.reset() + observation = _flatten_obs(timestep.observation) + info = {} + return observation, info + + def render(self, height=None, width=None, camera_id=None): + height = height or self.render_height + width = width or self.render_width + camera_id = camera_id or self.render_camera_id + return self._env.physics.render(height=height, width=width, camera_id=camera_id) + + +def make_env_dmc(env_name: str, action_repeat: int = 1) -> Env: + env = DMCGym(env_name=env_name, action_repeat=action_repeat) + env = RescaleAction(env, -1.0, 1.0) + env = FlattenObservation(env) + return env + + \ No newline at end of file diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index fe1c1e9..a3b29c3 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -239,6 +239,11 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, ) self._n_updates += gradient_steps + return { + 'actor_loss': actor_loss_value.item(), + 'critic_loss': qf_loss_value.item(), + 'ent_coef': ent_coef_value.item(), + } #self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") #self.logger.record("train/actor_loss", actor_loss_value.item()) #self.logger.record("train/critic_loss", qf_loss_value.item()) diff --git a/scripts/dmc/acrobot.sh b/scripts/dmc/acrobot.sh new file mode 100644 index 0000000..50070c2 --- /dev/null +++ b/scripts/dmc/acrobot.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup + +wait diff --git a/scripts/dmc/cheetah.sh b/scripts/dmc/cheetah.sh new file mode 100644 index 0000000..c6e8b3f --- /dev/null +++ b/scripts/dmc/cheetah.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run + +wait diff --git a/scripts/dmc/dog_run.sh b/scripts/dmc/dog_run.sh new file mode 100644 index 0000000..88fbfa7 --- /dev/null +++ b/scripts/dmc/dog_run.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run + +wait diff --git a/scripts/dmc/dog_stand.sh b/scripts/dmc/dog_stand.sh new file mode 100644 
index 0000000..7184f93 --- /dev/null +++ b/scripts/dmc/dog_stand.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand + +wait diff --git a/scripts/dmc/dog_trot.sh b/scripts/dmc/dog_trot.sh new file mode 100644 index 0000000..701d141 --- /dev/null +++ b/scripts/dmc/dog_trot.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot + +wait diff --git a/scripts/dmc/dog_walk.sh b/scripts/dmc/dog_walk.sh new file mode 100644 index 0000000..24f9635 --- /dev/null +++ b/scripts/dmc/dog_walk.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk + +wait diff --git a/scripts/dmc/finger.sh b/scripts/dmc/finger.sh new file mode 100644 index 0000000..8cf0d21 --- /dev/null +++ b/scripts/dmc/finger.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard + +wait diff --git a/scripts/dmc/fish.sh b/scripts/dmc/fish.sh new file mode 100644 index 0000000..c18e8e3 --- /dev/null +++ b/scripts/dmc/fish.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim + +wait diff --git a/scripts/dmc/hopper.sh b/scripts/dmc/hopper.sh new file mode 100644 index 0000000..7ecd2cc --- /dev/null +++ b/scripts/dmc/hopper.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop + +wait diff --git a/scripts/dmc/humanoid_run.sh b/scripts/dmc/humanoid_run.sh new file mode 100644 index 0000000..7f33786 --- /dev/null +++ b/scripts/dmc/humanoid_run.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run + +wait diff --git a/scripts/dmc/humanoid_stand.sh b/scripts/dmc/humanoid_stand.sh new file mode 100644 index 0000000..16fa76f --- /dev/null +++ b/scripts/dmc/humanoid_stand.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py 
--env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand + +wait diff --git a/scripts/dmc/humanoid_walk.sh b/scripts/dmc/humanoid_walk.sh new file mode 100644 index 0000000..651c2a3 --- /dev/null +++ b/scripts/dmc/humanoid_walk.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk + +wait diff --git a/scripts/dmc/pendulum.sh b/scripts/dmc/pendulum.sh new file mode 100644 index 0000000..5ca5fd4 --- /dev/null +++ b/scripts/dmc/pendulum.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup + +wait diff --git a/scripts/dmc/quadruped.sh b/scripts/dmc/quadruped.sh new file mode 100644 index 0000000..5de95f1 --- /dev/null +++ b/scripts/dmc/quadruped.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run + +wait diff --git a/scripts/dmc/walker.sh b/scripts/dmc/walker.sh new file mode 100644 index 0000000..3e97f7f --- /dev/null +++ b/scripts/dmc/walker.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=walker-run & python3 train_torch.py --env_name=walker-run & python3 train_torch.py --env_name=walker-run + +wait diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh deleted file mode 100755 index 9ecb207..0000000 --- a/scripts/run_tests.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. 
-v --color=yes -m "not expensive" diff --git a/train.py b/train.py new file mode 100644 index 0000000..cc38cfc --- /dev/null +++ b/train.py @@ -0,0 +1,110 @@ +import gymnasium as gym +from sbx.bro.bro import BRO +import numpy as np +from make_dmc import make_env_dmc +import wandb + +from absl import app, flags + +flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.') +flags.DEFINE_string('benchmark', 'dmc', 'Environment name.') +flags.DEFINE_integer('learning_starts', 2500, 'Number of training steps to start training.') +flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') +flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') +flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') +flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') +flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') +flags.DEFINE_integer('num_episodes', 5, 'Number of episodes used for evaluation.') +FLAGS = flags.FLAGS + +''' +class flags: + env_name: str = "cheetah-run" + learning_starts: int = 5000 + training_steps: int = 10_000 + seed: int = 0 + batch_size: int = 128 + gradient_steps: int = 2 + use_wandb: bool = False + n_quantiles: int = 100 + eval_freq: int = 5000 + num_episodes: int = 5 +FLAGS = flags() +''' + +def evaluate(env, model, num_episodes): + returns = np.zeros(num_episodes) + for episode in range(num_episodes): + not_done = True + obs, _ = env.reset(seed=np.random.randint(1e7)) + obs = np.expand_dims(obs, axis=0) + ret = 0 + while not_done: + action = model.policy.forward(obs, deterministic=True)[0] + next_obs, reward, term, trun, info = env.step(action) + next_obs = np.expand_dims(next_obs, axis=0) + obs = next_obs + ret += reward + if term or trun : + not_done = False + returns[episode] = ret + return {'return_eval': returns.mean()} + +def log_to_wandb(step, infos): + dict_to_log = {'timestep': step} + for info_key in infos: + dict_to_log[f'{info_key}'] = infos[info_key] + wandb.log(dict_to_log, step=step) + +def get_env(benchmark, env_name): + if benchmark == 'gym': + return gym.make(FLAGS.env_name) + else: + return make_env_dmc(env_name=FLAGS.env_name, action_repeat=1) + +def main(_): + SEED = np.random.randint(1e7) + wandb.init( + config=FLAGS, + entity='naumix', + project='BRO_SBX', + group=f'{FLAGS.env_name}_{SEED}', + name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}' + ) + + env = get_env(FLAGS.benchmark, FLAGS.env_name) + eval_env = get_env(FLAGS.benchmark, FLAGS.env_name) + model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED) + np.random.seed(SEED) + + obs, _ = env.reset(seed=np.random.randint(1e7)) + obs = np.expand_dims(obs, axis=0) + + for i in range(1, FLAGS.training_steps+1): + if i <= FLAGS.learning_starts: + action = env.action_space.sample() + else: + action = model.policy.forward(obs, deterministic=False)[0] + next_obs, reward, term, trun, info = env.step(action) + next_obs = np.expand_dims(next_obs, axis=0) + + done = 1.0 if (term and not trun) else 0.0 + + model.replay_buffer.add(obs, next_obs, action, reward, done, info) + if term or trun: + obs, _ = env.reset(seed=np.random.randint(1e7)) + obs = np.expand_dims(obs, axis=0) + else: + obs = next_obs + + if i >= FLAGS.learning_starts: + train_info = model.train(FLAGS.gradient_steps, FLAGS.batch_size) + + if i % FLAGS.eval_freq == 0: + eval_info = evaluate(eval_env, model, FLAGS.num_episodes) + info = {**eval_info, **train_info} + #print(eval_info) + log_to_wandb(i, 
info) + +if __name__ == '__main__': + app.run(main) From e291f699c02f938ecfbccb5d1e160ba681c3ff68 Mon Sep 17 00:00:00 2001 From: naumix Date: Fri, 1 Nov 2024 23:54:02 +0100 Subject: [PATCH 3/6] scripts --- scripts/dmc/acrobot.sh | 2 +- scripts/dmc/cheetah.sh | 2 +- scripts/dmc/dog_run.sh | 2 +- scripts/dmc/dog_stand.sh | 2 +- scripts/dmc/dog_trot.sh | 2 +- scripts/dmc/dog_walk.sh | 2 +- scripts/dmc/finger.sh | 2 +- scripts/dmc/fish.sh | 2 +- scripts/dmc/hopper.sh | 2 +- scripts/dmc/humanoid_run.sh | 2 +- scripts/dmc/humanoid_stand.sh | 2 +- scripts/dmc/humanoid_walk.sh | 2 +- scripts/dmc/pendulum.sh | 2 +- scripts/dmc/quadruped.sh | 2 +- scripts/dmc/walker.sh | 2 +- train.py => train_torch.py | 4 ++-- 16 files changed, 17 insertions(+), 17 deletions(-) rename train.py => train_torch.py (98%) diff --git a/scripts/dmc/acrobot.sh b/scripts/dmc/acrobot.sh index 50070c2..f5e59bf 100644 --- a/scripts/dmc/acrobot.sh +++ b/scripts/dmc/acrobot.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup +python3 train_torch.py --env_name=acrobot-swingup wait diff --git a/scripts/dmc/cheetah.sh b/scripts/dmc/cheetah.sh index c6e8b3f..9c025df 100644 --- a/scripts/dmc/cheetah.sh +++ b/scripts/dmc/cheetah.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run +python3 train_torch.py --env_name=cheetah-run wait diff --git a/scripts/dmc/dog_run.sh b/scripts/dmc/dog_run.sh index 88fbfa7..45b27ea 100644 --- a/scripts/dmc/dog_run.sh +++ b/scripts/dmc/dog_run.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run +python3 train_torch.py --env_name=dog-run wait diff --git a/scripts/dmc/dog_stand.sh b/scripts/dmc/dog_stand.sh index 7184f93..07a4584 100644 --- a/scripts/dmc/dog_stand.sh +++ b/scripts/dmc/dog_stand.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand +python3 train_torch.py --env_name=dog-stand wait diff --git a/scripts/dmc/dog_trot.sh b/scripts/dmc/dog_trot.sh index 701d141..d7b1411 100644 --- a/scripts/dmc/dog_trot.sh +++ b/scripts/dmc/dog_trot.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot +python3 train_torch.py --env_name=dog-trot wait diff --git a/scripts/dmc/dog_walk.sh b/scripts/dmc/dog_walk.sh index 24f9635..4221da2 100644 --- a/scripts/dmc/dog_walk.sh +++ b/scripts/dmc/dog_walk.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk +python3 train_torch.py --env_name=dog-walk wait diff --git a/scripts/dmc/finger.sh b/scripts/dmc/finger.sh index 8cf0d21..8a10360 100644 --- a/scripts/dmc/finger.sh +++ b/scripts/dmc/finger.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py 
--env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard +python3 train_torch.py --env_name=finger-turn_hard wait diff --git a/scripts/dmc/fish.sh b/scripts/dmc/fish.sh index c18e8e3..7dca0fa 100644 --- a/scripts/dmc/fish.sh +++ b/scripts/dmc/fish.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim +python3 train_torch.py --env_name=fish-swim wait diff --git a/scripts/dmc/hopper.sh b/scripts/dmc/hopper.sh index 7ecd2cc..fea40d9 100644 --- a/scripts/dmc/hopper.sh +++ b/scripts/dmc/hopper.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop +python3 train_torch.py --env_name=hopper-hop wait diff --git a/scripts/dmc/humanoid_run.sh b/scripts/dmc/humanoid_run.sh index 7f33786..f434d0b 100644 --- a/scripts/dmc/humanoid_run.sh +++ b/scripts/dmc/humanoid_run.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run +python3 train_torch.py --env_name=humanoid-run wait diff --git a/scripts/dmc/humanoid_stand.sh b/scripts/dmc/humanoid_stand.sh index 16fa76f..e22bff0 100644 --- a/scripts/dmc/humanoid_stand.sh +++ b/scripts/dmc/humanoid_stand.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand +python3 train_torch.py --env_name=humanoid-stand wait diff --git a/scripts/dmc/humanoid_walk.sh b/scripts/dmc/humanoid_walk.sh index 651c2a3..64a2a37 100644 --- a/scripts/dmc/humanoid_walk.sh +++ b/scripts/dmc/humanoid_walk.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk +python3 train_torch.py --env_name=humanoid-walk wait diff --git a/scripts/dmc/pendulum.sh b/scripts/dmc/pendulum.sh index 5ca5fd4..0e54c35 100644 --- a/scripts/dmc/pendulum.sh +++ b/scripts/dmc/pendulum.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup +python3 train_torch.py --env_name=pendulum-swingup wait diff --git a/scripts/dmc/quadruped.sh b/scripts/dmc/quadruped.sh index 5de95f1..e91ba25 100644 --- a/scripts/dmc/quadruped.sh +++ b/scripts/dmc/quadruped.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run +python3 train_torch.py --env_name=quadruped-run wait diff --git a/scripts/dmc/walker.sh b/scripts/dmc/walker.sh index 3e97f7f..3da6aa7 100644 --- a/scripts/dmc/walker.sh +++ b/scripts/dmc/walker.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=walker-run & python3 train_torch.py --env_name=walker-run & python3 
train_torch.py --env_name=walker-run +python3 train_torch.py --env_name=walker-run wait diff --git a/train.py b/train_torch.py similarity index 98% rename from train.py rename to train_torch.py index cc38cfc..ab106f5 100644 --- a/train.py +++ b/train_torch.py @@ -68,8 +68,8 @@ def main(_): config=FLAGS, entity='naumix', project='BRO_SBX', - group=f'{FLAGS.env_name}_{SEED}', - name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}' + group=f'{FLAGS.env_name}', + name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}' ) env = get_env(FLAGS.benchmark, FLAGS.env_name) From c2ca2f9a5bb8259322078d9e465092f2a50caa5a Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 2 Nov 2024 18:13:22 +0100 Subject: [PATCH 4/6] add --- sbx/bro/bro.py | 121 +++++++++++++++++++++++++++-- sbx/bro/policies.py | 2 +- sbx/common/off_policy_algorithm.py | 2 +- scripts/gym/ant.sh | 13 ++++ scripts/gym/cheetah.sh | 13 ++++ scripts/gym/hopper.sh | 13 ++++ scripts/gym/walker.sh | 13 ++++ train_torch.py | 15 ++-- 8 files changed, 179 insertions(+), 13 deletions(-) create mode 100644 scripts/gym/ant.sh create mode 100644 scripts/gym/cheetah.sh create mode 100644 scripts/gym/hopper.sh create mode 100644 scripts/gym/walker.sh diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index a3b29c3..70b05c6 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -27,8 +27,7 @@ class EntropyCoef(nn.Module): def __call__(self) -> jnp.ndarray: log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init))) return jnp.exp(log_ent_coef) - - + class ConstantEntropyCoef(nn.Module): ent_coef_init: float = 1.0 @@ -39,7 +38,27 @@ def __call__(self) -> float: self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init)) return self.ent_coef_init - +@jax.jit +def _get_stats( + actor_state: RLTrainState, + qf_state: RLTrainState, + ent_coef_state: TrainState, + observations: jax.Array, + key: jax.Array, +): + key, dropout_key, noise_key = jax.random.split(key, 3) + dist = actor_state.apply_fn(actor_state.params, observations) + actor_actions = dist.sample(seed=noise_key) + log_prob = dist.log_prob(actor_actions).reshape(-1, 1) + qf_pi = qf_state.apply_fn( + qf_state.params, + observations, + actor_actions, + rngs={"dropout": dropout_key}, + ) + ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) + return qf_pi.mean(), jnp.absolute(actor_actions).mean(), ent_coef_value.mean(), -log_prob.mean() + class BRO(OffPolicyAlgorithmJax): policy_aliases: ClassVar[Dict[str, Type[BROPolicy]]] = { # type: ignore[assignment] "MlpPolicy": BROPolicy, @@ -110,11 +129,13 @@ def __init__( self.policy_delay = policy_delay self.ent_coef_init = ent_coef self.target_entropy = target_entropy + self.init_key = jax.random.PRNGKey(seed) self.n_quantiles = n_quantiles taus_ = jnp.arange(0, n_quantiles+1) / n_quantiles self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None] + self.distributional = True if self.n_quantiles > 1 else False if _init_setup_model: self._setup_model() @@ -122,6 +143,61 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() + if not hasattr(self, "policy") or self.policy is None: + self.policy = self.policy_class( # type: ignore[assignment] + self.observation_space, + self.action_space, + self.lr_schedule, + self.n_quantiles, + **self.policy_kwargs, + ) + + assert isinstance(self.qf_learning_rate, float) + self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) + + self.key, ent_key = 
jax.random.split(self.key, 2) + + self.actor = self.policy.actor # type: ignore[assignment] + self.qf = self.policy.qf # type: ignore[assignment] + + # The entropy coefficient or entropy can be learned automatically + # see Automating Entropy Adjustment for Maximum Entropy RL section + # of https://arxiv.org/abs/1812.05905 + if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"): + # Default initial value of ent_coef when learned + ent_coef_init = 1.0 + if "_" in self.ent_coef_init: + ent_coef_init = float(self.ent_coef_init.split("_")[1]) + assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0" + + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 + self.ent_coef = EntropyCoef(ent_coef_init) + else: + # This will throw an error if a malformed string (different from 'auto') is passed + assert isinstance( + self.ent_coef_init, float + ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}" + self.ent_coef = ConstantEntropyCoef(self.ent_coef_init) # type: ignore[assignment] + + self.ent_coef_state = TrainState.create( + apply_fn=self.ent_coef.apply, + params=self.ent_coef.init(ent_key)["params"], + tx=optax.adam( + learning_rate=self.learning_rate, b1=0.5 + ), + ) + + # Target entropy is used when learning the entropy coefficient + if self.target_entropy == "auto": + # automatically set target entropy if needed + self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) / 2 # type: ignore + else: + # Force conversion + # this will also throw an error for unexpected string + self.target_entropy = float(self.target_entropy) + + def reset(self): if not hasattr(self, "policy") or self.policy is None: self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, @@ -133,7 +209,7 @@ def _setup_model(self) -> None: assert isinstance(self.qf_learning_rate, float) - self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) + self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) self.key, ent_key = jax.random.split(self.key, 2) @@ -164,7 +240,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.learning_rate, b1=0.5 ), ) @@ -242,7 +318,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: return { 'actor_loss': actor_loss_value.item(), 'critic_loss': qf_loss_value.item(), - 'ent_coef': ent_coef_value.item(), + 'ent_loss': ent_coef_value.item(), } #self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") #self.logger.record("train/actor_loss", actor_loss_value.item()) @@ -431,7 +507,40 @@ def update_actor_and_temperature( ) ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key + + def get_stats(self, batch_size): + data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + if isinstance(data.observations, dict): + keys = list(self.observation_space.keys()) # type: ignore[attr-defined] + obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) + next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1) + else: + obs = 
data.observations.numpy() + next_obs = data.next_observations.numpy() + + # Convert to numpy + data = ReplayBufferSamplesNp( # type: ignore[assignment] + obs, + data.actions.numpy(), + next_obs, + data.dones.numpy().flatten(), + data.rewards.numpy().flatten(), + ) + q, a, temp, ent = _get_stats( + self.policy.actor_state, + self.policy.qf_state, + self.ent_coef_state, + obs, + self.key, + ) + + return { + 'q': q.mean().item(), + 'a': a.item(), + 'temp': temp.item(), + 'entropy': ent.item()} + @classmethod @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"]) def _train( diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 5297d6e..1bee8bd 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -128,7 +128,7 @@ def __init__( features_extractor_class=None, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index fddf57b..11e6c8c 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -71,7 +71,7 @@ def __init__( support_multi_env=support_multi_env, ) # Will be updated later - self.key = jax.random.PRNGKey(0) + self.key = jax.random.PRNGKey(seed) # Note: we do not allow schedule for it self.qf_learning_rate = qf_learning_rate diff --git a/scripts/gym/ant.sh b/scripts/gym/ant.sh new file mode 100644 index 0000000..1d63aee --- /dev/null +++ b/scripts/gym/ant.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Ant-v4 + +wait diff --git a/scripts/gym/cheetah.sh b/scripts/gym/cheetah.sh new file mode 100644 index 0000000..63108f6 --- /dev/null +++ b/scripts/gym/cheetah.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=HalfCheetah-v4 + +wait diff --git a/scripts/gym/hopper.sh b/scripts/gym/hopper.sh new file mode 100644 index 0000000..3eeca18 --- /dev/null +++ b/scripts/gym/hopper.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Hopper-v4 + +wait diff --git a/scripts/gym/walker.sh b/scripts/gym/walker.sh new file mode 100644 index 0000000..8a3e68c --- /dev/null +++ b/scripts/gym/walker.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Walker2d-v4 + +wait diff --git a/train_torch.py b/train_torch.py index ab106f5..0aa57d1 100644 --- a/train_torch.py +++ b/train_torch.py @@ -8,9 +8,9 @@ flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.') flags.DEFINE_string('benchmark', 'dmc', 'Environment name.') -flags.DEFINE_integer('learning_starts', 2500, 'Number of training steps to start training.') 
+flags.DEFINE_integer('learning_starts', 2000, 'Number of training steps to start training.') flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') -flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') +flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') @@ -74,9 +74,10 @@ def main(_): env = get_env(FLAGS.benchmark, FLAGS.env_name) eval_env = get_env(FLAGS.benchmark, FLAGS.env_name) - model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED) + model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, learning_starts=FLAGS.learning_starts, gradient_steps=FLAGS.gradient_steps) np.random.seed(SEED) + reset_list = [20000] obs, _ = env.reset(seed=np.random.randint(1e7)) obs = np.expand_dims(obs, axis=0) @@ -97,12 +98,16 @@ def main(_): else: obs = next_obs + if i in reset_list: + model.reset() + if i >= FLAGS.learning_starts: train_info = model.train(FLAGS.gradient_steps, FLAGS.batch_size) - + if i % FLAGS.eval_freq == 0: eval_info = evaluate(eval_env, model, FLAGS.num_episodes) - info = {**eval_info, **train_info} + stat_info = model.get_stats(FLAGS.batch_size) + info = {**eval_info, **train_info, **stat_info} #print(eval_info) log_to_wandb(i, info) From 6599938b6cef194d00825b781752b1bbcb30e4e2 Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 2 Nov 2024 22:15:27 +0100 Subject: [PATCH 5/6] Update policies.py --- sbx/bro/policies.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 1bee8bd..576ef80 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -21,7 +21,7 @@ class BroNetBlock(nn.Module): activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: out = nn.Dense(self.n_units)(x) out = nn.LayerNorm()(out) out = self.activation_fn(out) @@ -74,10 +74,10 @@ class Critic(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: x = Flatten()(x) - x = jnp.concatenate([x, action], -1) - x = BroNet(self.net_arch, self.activation_fn)(x) - x = nn.Dense(self.n_quantiles)(x) - return x + out = jnp.concatenate([x, action], -1) + out = BroNet(self.net_arch, self.activation_fn)(out) + out = nn.Dense(self.n_quantiles)(out) + return out class VectorCritic(nn.Module): net_arch: Sequence[int] @@ -139,7 +139,7 @@ def __init__( # but shows only little overall improvement. 
optimizer_kwargs = {} if optimizer_class in [optax.adam, optax.adamw]: - optimizer_kwargs["b1"] = 0.5 + pass super().__init__( observation_space, @@ -163,7 +163,7 @@ def __init__( else: self.net_arch_pi = [256] # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR - self.net_arch_qf = [1024, 1024] + self.net_arch_qf = [512, 512] print(self.net_arch_qf) self.n_critics = n_critics self.use_sde = use_sde From 9e2c6dd31f6037e3631cd0395dbd2e3f741eab24 Mon Sep 17 00:00:00 2001 From: naumix Date: Sun, 3 Nov 2024 13:47:04 +0100 Subject: [PATCH 6/6] add --- sbx/bro/policies.py | 2 +- scripts/gym/humanoid.sh | 13 +++++++++++++ train_torch.py | 15 ++++++++------- 3 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 scripts/gym/humanoid.sh diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 576ef80..dc2a56e 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -117,7 +117,7 @@ def __init__( n_quantiles: int = 100, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, dropout_rate: float = 0.0, - layer_norm: bool = False, + layer_norm: bool = True, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, # Note: most gSDE parameters are not used diff --git a/scripts/gym/humanoid.sh b/scripts/gym/humanoid.sh new file mode 100644 index 0000000..7d2a7d2 --- /dev/null +++ b/scripts/gym/humanoid.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Humanoid-v4 + +wait diff --git a/train_torch.py b/train_torch.py index 0aa57d1..f747f7d 100644 --- a/train_torch.py +++ b/train_torch.py @@ -8,19 +8,20 @@ flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.') flags.DEFINE_string('benchmark', 'dmc', 'Environment name.') -flags.DEFINE_integer('learning_starts', 2000, 'Number of training steps to start training.') +flags.DEFINE_integer('learning_starts', 2500, 'Number of training steps to start training.') flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') -flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') +flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') -flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') +flags.DEFINE_integer('n_quantiles', 1, 'Number of training steps.') flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') flags.DEFINE_integer('num_episodes', 5, 'Number of episodes used for evaluation.') FLAGS = flags.FLAGS ''' class flags: + benchmark: str = 'dmc' env_name: str = "cheetah-run" - learning_starts: int = 5000 + learning_starts: int = 9999 training_steps: int = 10_000 seed: int = 0 batch_size: int = 128 @@ -69,15 +70,15 @@ def main(_): entity='naumix', project='BRO_SBX', group=f'{FLAGS.env_name}', - name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}' + name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}_repro' ) env = get_env(FLAGS.benchmark, FLAGS.env_name) eval_env = get_env(FLAGS.benchmark, FLAGS.env_name) - model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, learning_starts=FLAGS.learning_starts, gradient_steps=FLAGS.gradient_steps) + model = BRO("MlpPolicy", env, 
learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, gradient_steps=FLAGS.gradient_steps) np.random.seed(SEED) - reset_list = [20000] + reset_list = [15000] obs, _ = env.reset(seed=np.random.randint(1e7)) obs = np.expand_dims(obs, axis=0)