From c242e9cb9359fa43d3e42c77285561a0b28478f7 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:15:37 -0700 Subject: [PATCH 01/18] Update ppo.py To resolve issue #207 in cleanrl, extra advantage code not needed --- cleanrl/ppo.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/cleanrl/ppo.py b/cleanrl/ppo.py index 0c98eae3..f6b04720 100644 --- a/cleanrl/ppo.py +++ b/cleanrl/ppo.py @@ -216,30 +216,18 @@ def get_action_and_value(self, x, action=None): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From 17705424efb446bebc8c45c385893f4850ff462a Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:17:19 -0700 Subject: [PATCH 02/18] Update ppo_atari.py To resolve issue #207 in cleanrl, extra advantage calc code unnecessary --- cleanrl/ppo_atari.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/cleanrl/ppo_atari.py b/cleanrl/ppo_atari.py index 11d62f3c..0ba1e5a2 100644 --- a/cleanrl/ppo_atari.py +++ b/cleanrl/ppo_atari.py @@ -232,30 +232,18 @@ def get_action_and_value(self, x, action=None): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - 
nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From 595d48ce71f215d04fbe75d2f392678dd202a580 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:23:52 -0700 Subject: [PATCH 03/18] Update ppo_atari_envpool.py Updated to resolve issue #207, unncessary additional advantage calc code --- cleanrl/ppo_atari_envpool.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/cleanrl/ppo_atari_envpool.py b/cleanrl/ppo_atari_envpool.py index 7ff9ab0f..2ba2464e 100644 --- a/cleanrl/ppo_atari_envpool.py +++ b/cleanrl/ppo_atari_envpool.py @@ -259,30 +259,18 @@ def get_action_and_value(self, x, action=None): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From 46a1839120d0f0e78fe2711f3d655acff8d9f8a5 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:24:41 -0700 Subject: [PATCH 04/18] Update ppo_continuous_action.py Updated to resolve issue #207 --- cleanrl/ppo_continuous_action.py | 36 +++++++++++--------------------- 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/cleanrl/ppo_continuous_action.py b/cleanrl/ppo_continuous_action.py index d1bf83fd..1bac9b8e 100644 --- a/cleanrl/ppo_continuous_action.py +++ b/cleanrl/ppo_continuous_action.py @@ -224,30 +224,18 @@ 
def get_action_and_value(self, x, action=None): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From 465dee6cfb1371b551365732fae7d2afcfdb97d7 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:26:43 -0700 Subject: [PATCH 05/18] Update ppo_atari_lstm.py Updated to resolve issue #207, unnecessary additional advantage calc code for ppo implementations --- cleanrl/ppo_atari_lstm.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/cleanrl/ppo_atari_lstm.py b/cleanrl/ppo_atari_lstm.py index 9c0f9bc9..dff3c91b 100644 --- a/cleanrl/ppo_atari_lstm.py +++ b/cleanrl/ppo_atari_lstm.py @@ -268,30 +268,18 @@ def get_action_and_value(self, x, lstm_state, done, action=None): next_lstm_state, next_done, ).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * 
nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From fd3665aa514826b997a0fd55d520074e7511d9db Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:28:05 -0700 Subject: [PATCH 06/18] Update ppo_pettingzoo_ma_atari.py Updated to resolve issue #207 --- cleanrl/ppo_pettingzoo_ma_atari.py | 36 ++++++++++-------------------- 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index 7fdd4056..e9345ec9 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -219,30 +219,18 @@ def get_action_and_value(self, x, action=None): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From b0994b7f468d7693768caa11f270233e7d449d22 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:28:42 -0700 Subject: [PATCH 07/18] Update ppo_procgen.py Updated to resolve issue #207 --- cleanrl/ppo_procgen.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/cleanrl/ppo_procgen.py b/cleanrl/ppo_procgen.py index 5e53870e..83e73a04 100644 --- a/cleanrl/ppo_procgen.py +++ b/cleanrl/ppo_procgen.py @@ -249,30 +249,18 @@ def get_action_and_value(self, x, action=None): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + 
args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From a4dfadf544c3a0a536e7e284d8f8994574202a45 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:01:49 -0700 Subject: [PATCH 08/18] GAE revisions #207 --- cleanrl/ppo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cleanrl/ppo.py b/cleanrl/ppo.py index f6b04720..4a7c2999 100644 --- a/cleanrl/ppo.py +++ b/cleanrl/ppo.py @@ -48,8 +48,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, From f80331288e51efbcf1b86f2bd68106c9573dab81 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:02:16 -0700 Subject: [PATCH 09/18] GAE revisions #207 --- cleanrl/ppo_atari.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cleanrl/ppo_atari.py b/cleanrl/ppo_atari.py index 0ba1e5a2..14be7a47 100644 --- a/cleanrl/ppo_atari.py +++ b/cleanrl/ppo_atari.py @@ -55,8 +55,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, From 0a7df4509ba7f772d9944c1a351f862add00707a Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:02:55 -0700 Subject: [PATCH 10/18] GAE revisions #207 --- cleanrl/ppo_atari_envpool.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cleanrl/ppo_atari_envpool.py b/cleanrl/ppo_atari_envpool.py index 2ba2464e..838579bb 100644 --- a/cleanrl/ppo_atari_envpool.py +++ b/cleanrl/ppo_atari_envpool.py @@ -49,8 +49,6 @@ def parse_args(): help="the number of steps to run in each environment 
per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, From 54a1d979abd1bb43fddd8ee9e449bce838697ff1 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:03:21 -0700 Subject: [PATCH 11/18] GAE revisions #207 --- cleanrl/ppo_atari_lstm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cleanrl/ppo_atari_lstm.py b/cleanrl/ppo_atari_lstm.py index dff3c91b..a90aa4ce 100644 --- a/cleanrl/ppo_atari_lstm.py +++ b/cleanrl/ppo_atari_lstm.py @@ -55,8 +55,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, From a35b560e103964bf17014ddb45b8a4e95eb15a8b Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:04:58 -0700 Subject: [PATCH 12/18] GAE revisions #207 --- cleanrl/ppo_atari_multigpu.py | 38 +++++++++++------------------------ 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/cleanrl/ppo_atari_multigpu.py b/cleanrl/ppo_atari_multigpu.py index 22763871..8955e129 100644 --- a/cleanrl/ppo_atari_multigpu.py +++ b/cleanrl/ppo_atari_multigpu.py @@ -57,8 +57,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, @@ -274,30 +272,18 @@ def get_action_and_value(self, x, action=None): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages = torch.zeros_like(rewards).to(device) - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return 
= returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From b5cf6b67aa91421d193065e39b057b7bd57869b3 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:05:27 -0700 Subject: [PATCH 13/18] GAE revisions #207 --- cleanrl/ppo_continuous_action.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cleanrl/ppo_continuous_action.py b/cleanrl/ppo_continuous_action.py index 1bac9b8e..80086348 100644 --- a/cleanrl/ppo_continuous_action.py +++ b/cleanrl/ppo_continuous_action.py @@ -48,8 +48,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, From 84e20ba575685f4b1cc85473798fea9f746cdec1 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:06:09 -0700 Subject: [PATCH 14/18] GAE revisions #207 --- cleanrl/ppo_pettingzoo_ma_atari.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index e9345ec9..bc51c703 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -49,8 +49,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, From e21863c4d56968b40a6c3e056de25862155fa697 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:06:38 -0700 Subject: [PATCH 15/18] GAE revisions #207 --- cleanrl/ppo_procgen.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cleanrl/ppo_procgen.py b/cleanrl/ppo_procgen.py index 83e73a04..9a93eb0c 100644 --- a/cleanrl/ppo_procgen.py +++ b/cleanrl/ppo_procgen.py @@ -48,8 +48,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - 
parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.999, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, From 33f06836b3ac7977acf340b3913742277ba57c86 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:09:47 -0700 Subject: [PATCH 16/18] GAE revisions #207 --- .../ppo_continuous_action_isaacgym.py | 38 ++++++------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py b/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py index f996dfde..ddf3cf89 100644 --- a/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py +++ b/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py @@ -77,8 +77,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.99, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, @@ -303,30 +301,18 @@ def observation(self, obs): # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) - if args.gae: - advantages[:] = 0 - lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - nextvalues = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - nextvalues = values[t + 1] - delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] - advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam - returns = advantages + values - else: - returns = torch.zeros_like(rewards).to(device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - nextnonterminal = 1.0 - next_done - next_return = next_value - else: - nextnonterminal = 1.0 - dones[t + 1] - next_return = returns[t + 1] - returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return - advantages = returns - values + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From 6cc5efd0f11a2cfc56a4c321030263be2b1ff827 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 16:14:25 -0700 Subject: [PATCH 17/18] GAE revisions #207 --- cleanrl/ppo_rnd_envpool.py | 67 +++++++++++++------------------------- 1 file changed, 23 insertions(+), 44 deletions(-) diff --git a/cleanrl/ppo_rnd_envpool.py b/cleanrl/ppo_rnd_envpool.py index 
4529475b..1098ea12 100644 --- a/cleanrl/ppo_rnd_envpool.py +++ b/cleanrl/ppo_rnd_envpool.py @@ -49,8 +49,6 @@ def parse_args(): help="the number of steps to run in each environment per policy rollout") parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Use GAE for advantage computation") parser.add_argument("--gamma", type=float, default=0.999, help="the discount factor gamma") parser.add_argument("--gae-lambda", type=float, default=0.95, @@ -413,50 +411,31 @@ def update(self, rews): with torch.no_grad(): next_value_ext, next_value_int = agent.get_value(next_obs) next_value_ext, next_value_int = next_value_ext.reshape(1, -1), next_value_int.reshape(1, -1) - if args.gae: - ext_advantages = torch.zeros_like(rewards, device=device) - int_advantages = torch.zeros_like(curiosity_rewards, device=device) - ext_lastgaelam = 0 - int_lastgaelam = 0 - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - ext_nextnonterminal = 1.0 - next_done - int_nextnonterminal = 1.0 - ext_nextvalues = next_value_ext - int_nextvalues = next_value_int - else: - ext_nextnonterminal = 1.0 - dones[t + 1] - int_nextnonterminal = 1.0 - ext_nextvalues = ext_values[t + 1] - int_nextvalues = int_values[t + 1] - ext_delta = rewards[t] + args.gamma * ext_nextvalues * ext_nextnonterminal - ext_values[t] - int_delta = curiosity_rewards[t] + args.int_gamma * int_nextvalues * int_nextnonterminal - int_values[t] - ext_advantages[t] = ext_lastgaelam = ( - ext_delta + args.gamma * args.gae_lambda * ext_nextnonterminal * ext_lastgaelam + ext_advantages = torch.zeros_like(rewards, device=device) + int_advantages = torch.zeros_like(curiosity_rewards, device=device) + ext_lastgaelam = 0 + int_lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + ext_nextnonterminal = 1.0 - next_done + int_nextnonterminal = 1.0 + ext_nextvalues = next_value_ext + int_nextvalues = next_value_int + else: + ext_nextnonterminal = 1.0 - dones[t + 1] + int_nextnonterminal = 1.0 + ext_nextvalues = ext_values[t + 1] + int_nextvalues = int_values[t + 1] + ext_delta = rewards[t] + args.gamma * ext_nextvalues * ext_nextnonterminal - ext_values[t] + int_delta = curiosity_rewards[t] + args.int_gamma * int_nextvalues * int_nextnonterminal - int_values[t] + ext_advantages[t] = ext_lastgaelam = ( + ext_delta + args.gamma * args.gae_lambda * ext_nextnonterminal * ext_lastgaelam ) - int_advantages[t] = int_lastgaelam = ( - int_delta + args.int_gamma * args.gae_lambda * int_nextnonterminal * int_lastgaelam + int_advantages[t] = int_lastgaelam = ( + int_delta + args.int_gamma * args.gae_lambda * int_nextnonterminal * int_lastgaelam ) - ext_returns = ext_advantages + ext_values - int_returns = int_advantages + int_values - else: - ext_returns = torch.zeros_like(rewards, device=device) - int_returns = torch.zeros_like(curiosity_rewards, device=device) - for t in reversed(range(args.num_steps)): - if t == args.num_steps - 1: - ext_nextnonterminal = 1.0 - next_done - int_nextnonterminal = 1.0 - ext_next_return = next_value_ext - int_next_return = next_value_int - else: - ext_nextnonterminal = 1.0 - dones[t + 1] - int_nextnonterminal = 1.0 - ext_next_return = ext_returns[t + 1] - int_next_return = int_returns[t + 1] - ext_returns[t] = rewards[t] + args.gamma * ext_nextnonterminal * 
ext_next_return - int_returns[t] = curiosity_rewards[t] + args.int_gamma * int_nextnonterminal * int_next_return - ext_advantages = ext_returns - ext_values - int_advantages = int_returns - int_values + ext_returns = ext_advantages + ext_values + int_returns = int_advantages + int_values # flatten the batch b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) From 06c5507048b05576a3359496a54018ccf84f8e32 Mon Sep 17 00:00:00 2001 From: bragajj <39658109+bragajj@users.noreply.github.com> Date: Mon, 3 Oct 2022 17:51:31 -0700 Subject: [PATCH 18/18] Update ppo_rnd_envpool.py Fixed styling of lines 432-436 --- cleanrl/ppo_rnd_envpool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cleanrl/ppo_rnd_envpool.py b/cleanrl/ppo_rnd_envpool.py index 1098ea12..6d176dab 100644 --- a/cleanrl/ppo_rnd_envpool.py +++ b/cleanrl/ppo_rnd_envpool.py @@ -430,10 +430,10 @@ def update(self, rews): int_delta = curiosity_rewards[t] + args.int_gamma * int_nextvalues * int_nextnonterminal - int_values[t] ext_advantages[t] = ext_lastgaelam = ( ext_delta + args.gamma * args.gae_lambda * ext_nextnonterminal * ext_lastgaelam - ) + ) int_advantages[t] = int_lastgaelam = ( int_delta + args.int_gamma * args.gae_lambda * int_nextnonterminal * int_lastgaelam - ) + ) ext_returns = ext_advantages + ext_values int_returns = int_advantages + int_values
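Background for the change above (a minimal sketch, not taken from the patches themselves): the non-GAE branch these commits delete is the special case gae_lambda=1.0 of the GAE loop that is kept, so a single code path computes everything the old --gae flag toggled. The snippet below checks that equivalence numerically on made-up rollout tensors; the names rewards, values, dones, next_value, and next_done mirror the buffers in the diffs, while num_steps, num_envs, and gamma are arbitrary stand-in values for the example only.

    import torch

    num_steps, num_envs = 8, 4          # stand-in rollout dimensions
    gamma, gae_lambda = 0.99, 1.0       # gae_lambda=1.0 is the case covered by the removed branch
    torch.manual_seed(0)

    # fake rollout buffers shaped like the ones in the scripts
    rewards = torch.randn(num_steps, num_envs)
    values = torch.randn(num_steps, num_envs)
    dones = (torch.rand(num_steps, num_envs) < 0.1).float()
    next_value = torch.randn(1, num_envs)
    next_done = (torch.rand(num_envs) < 0.1).float()

    # retained path: the GAE recursion kept by the patches
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    returns = advantages + values

    # removed path: discounted returns, then advantages = returns - values
    returns_old = torch.zeros_like(rewards)
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            next_return = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            next_return = returns_old[t + 1]
        returns_old[t] = rewards[t] + gamma * nextnonterminal * next_return
    advantages_old = returns_old - values

    # with gae_lambda=1.0 the two computations agree up to floating-point error
    assert torch.allclose(advantages, advantages_old, atol=1e-6)
    assert torch.allclose(returns, returns_old, atol=1e-6)

Lowering gae_lambda below 1.0 then recovers the usual bias/variance trade-off of generalized advantage estimation, which is the only behavior the retained loop adds over the deleted one.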