Removing the regular advantage calculation in PPO #207

Closed · Tracked by #206
vwxyzjn opened this issue Jun 20, 2022 · 2 comments

vwxyzjn (Owner) commented Jun 20, 2022

Problem description.

The regular advantage calculation in PPO is a special case of the GAE advantage calculation when gae_lambda=1, as the debugging output at the bottom demonstrates empirically. Based on this result, we should remove

cleanrl/cleanrl/ppo.py

Lines 232 to 242 in 94a685d

else:
    returns = torch.zeros_like(rewards).to(device)
    for t in reversed(range(args.num_steps)):
        if t == args.num_steps - 1:
            nextnonterminal = 1.0 - next_done
            next_return = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            next_return = returns[t + 1]
        returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
    advantages = returns - values
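
As a quick sanity check independent of the rollout below, here is a minimal sketch on synthetic tensors (random data; the shapes and variable names only mirror ppo.py, so this is an illustration rather than the original repro) showing that the two recursions agree when gae_lambda=1:

# Standalone check on synthetic rollout tensors (assumed shapes: [num_steps, num_envs]).
import torch

torch.manual_seed(0)
num_steps, num_envs, gamma, gae_lambda = 8, 4, 0.99, 1.0  # gae_lambda=1 is the point

rewards = torch.randn(num_steps, num_envs)
values = torch.randn(num_steps, num_envs)
dones = (torch.rand(num_steps, num_envs) < 0.1).float()
next_done = (torch.rand(num_envs) < 0.1).float()
next_value = torch.randn(num_envs)

# (1) "regular" returns recursion, as in the block slated for removal
returns = torch.zeros_like(rewards)
for t in reversed(range(num_steps)):
    nextnonterminal = 1.0 - (next_done if t == num_steps - 1 else dones[t + 1])
    next_return = next_value if t == num_steps - 1 else returns[t + 1]
    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
adv_regular = returns - values

# (2) GAE recursion evaluated with gae_lambda = 1
adv_gae = torch.zeros_like(rewards)
lastgaelam = 0
for t in reversed(range(num_steps)):
    nextnonterminal = 1.0 - (next_done if t == num_steps - 1 else dones[t + 1])
    nextvalues = next_value if t == num_steps - 1 else values[t + 1]
    delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
    adv_gae[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

print(torch.allclose(adv_regular, adv_gae, atol=1e-6))  # True
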

Debugging output

(cleanrl-ghSZGHE3-py3.9) ➜  cleanrl git:(explain-non-modular) ✗ ipython -i ppo.py  
Python 3.9.5 (default, Jul 19 2021, 13:27:26) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.4.0 -- An enhanced Interactive Python. Type '?' for help.
/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/gym/utils/passive_env_checker.py:97: UserWarning: WARN: We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html
  logger.warn(
/home/costa/.cache/pypoetry/virtualenvs/cleanrl-ghSZGHE3-py3.9/lib/python3.9/site-packages/gym/core.py:200: DeprecationWarning: WARN: Function `env.seed(seed)` is marked as deprecated and will be removed in the future. Please use `env.reset(seed=seed)` instead.
  deprecation(
global_step=36, episodic_return=9.0
global_step=52, episodic_return=13.0
global_step=100, episodic_return=25.0
global_step=112, episodic_return=19.0
global_step=128, episodic_return=32.0
global_step=144, episodic_return=11.0
global_step=152, episodic_return=13.0
global_step=176, episodic_return=12.0
global_step=196, episodic_return=11.0
global_step=228, episodic_return=13.0
global_step=260, episodic_return=16.0
global_step=296, episodic_return=46.0
global_step=300, episodic_return=39.0
global_step=312, episodic_return=13.0
global_step=360, episodic_return=15.0
global_step=388, episodic_return=40.0
global_step=400, episodic_return=22.0
global_step=408, episodic_return=28.0
global_step=440, episodic_return=13.0
global_step=460, episodic_return=15.0
global_step=484, episodic_return=31.0
global_step=500, episodic_return=23.0
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/Documents/go/src/github.com/cleanrl/cleanrl/ppo.py:243, in <module>
    241             returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
    242         advantages = returns - values
--> 243 raise
    244 # flatten the batch
    245 b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)

RuntimeError: No active exception to reraise

In [1]:                 returns = torch.zeros_like(rewards).to(device)
   ...:                 for t in reversed(range(args.num_steps)):
   ...:                     if t == args.num_steps - 1:
   ...:                         nextnonterminal = 1.0 - next_done
   ...:                         next_return = next_value
   ...:                     else:
   ...:                         nextnonterminal = 1.0 - dones[t + 1]
   ...:                         next_return = returns[t + 1]
   ...:                     returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
   ...:                 advantages = returns - values

In [2]: returns.sum()
Out[2]: tensor(6017.7227, device='cuda:0')

In [3]: advantages.sum()
Out[3]: tensor(6005.0435, device='cuda:0')

In [4]:                 advantages = torch.zeros_like(rewards).to(device)
   ...:                 lastgaelam = 0
   ...:                 for t in reversed(range(args.num_steps)):
   ...:                     if t == args.num_steps - 1:
   ...:                         nextnonterminal = 1.0 - next_done
   ...:                         nextvalues = next_value
   ...:                     else:
   ...:                         nextnonterminal = 1.0 - dones[t + 1]
   ...:                         nextvalues = values[t + 1]
   ...:                     delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
   ...:                     advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
   ...:                 returns = advantages + values

In [5]: returns.sum()
Out[5]: tensor(4088.1948, device='cuda:0')

In [6]: advantages.sum()
Out[6]: tensor(4075.5161, device='cuda:0')

In [7]: args.gae_lambda
Out[7]: 0.95

In [8]: args.gae_lambda = 1

In [9]:                 advantages = torch.zeros_like(rewards).to(device)
   ...:                 lastgaelam = 0
   ...:                 for t in reversed(range(args.num_steps)):
   ...:                     if t == args.num_steps - 1:
   ...:                         nextnonterminal = 1.0 - next_done
   ...:                         nextvalues = next_value
   ...:                     else:
   ...:                         nextnonterminal = 1.0 - dones[t + 1]
   ...:                         nextvalues = values[t + 1]
   ...:                     delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
   ...:                     advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
   ...:                 returns = advantages + values

In [10]: returns.sum()
Out[10]: tensor(6017.7227, device='cuda:0')

In [11]: advantages.sum()
Out[11]: tensor(6005.0435, device='cuda:0')
vwxyzjn mentioned this issue Jun 20, 2022 (5 tasks)
vwxyzjn changed the title from "Removing the regular advantage calculation" to "Removing the regular advantage calculation in PPO" Jun 20, 2022
Howuhh (Contributor) commented Jun 20, 2022

Yup, this can also be verified directly from the GAE computation formula: with gae_lambda=1, all terms except those needed for the n-step returns cancel out.
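
For reference, a sketch of that cancellation under the standard GAE definitions (Schulman et al., 2015), with termination masking omitted for clarity. With $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ and $\lambda = 1$:

$$
\hat{A}_t^{\mathrm{GAE}(\gamma,1)}
= \sum_{l=0}^{T-t-1} \gamma^l \delta_{t+l}
= \sum_{l=0}^{T-t-1} \gamma^l r_{t+l} + \gamma^{T-t} V(s_T) - V(s_t)
= \hat{R}_t - V(s_t),
$$

where the intermediate $V$ terms telescope away and $\hat{R}_t$ is exactly the bootstrapped return computed by the block slated for removal, so the two advantage estimates coincide.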

bragajj added commits to bragajj/cleanrl that referenced this issue Oct 3, 2022:
- To resolve issue vwxyzjn#207 in cleanrl, extra advantage code not needed
- To resolve issue vwxyzjn#207 in cleanrl, extra advantage calc code unnecessary
- Updated to resolve issue vwxyzjn#207, unnecessary additional advantage calc code
- Updated to resolve issue vwxyzjn#207, unnecessary additional advantage calc code for ppo implementations
- Updated to resolve issue vwxyzjn#207
vwxyzjn pushed a commit that referenced this issue Oct 4, 2022
* Update ppo.py

To resolve issue #207 in cleanrl, extra advantage code not needed

* Update ppo_atari.py

To resolve issue #207 in cleanrl, extra advantage calc code unnecessary

* Update ppo_atari_envpool.py

Updated to resolve issue #207, unnecessary additional advantage calc code

* Update ppo_continuous_action.py

Updated to resolve issue #207

* Update ppo_atari_lstm.py

Updated to resolve issue #207, unnecessary additional advantage calc code for ppo implementations

* Update ppo_pettingzoo_ma_atari.py

Updated to resolve issue #207

* Update ppo_procgen.py

Updated to resolve issue #207

* GAE revisions #207

* Update ppo_rnd_envpool.py

Fixed styling of lines 432-436
vwxyzjn (Owner, Author) commented Oct 4, 2022

Closed by #287

vwxyzjn closed this as completed Oct 4, 2022