diff --git a/week5_policy_based/atari_util.py b/week5_policy_based/atari_util.py index 736db61c2..e888e36ef 100644 --- a/week5_policy_based/atari_util.py +++ b/week5_policy_based/atari_util.py @@ -7,7 +7,7 @@ class PreprocessAtari(Wrapper): def __init__(self, env, height=42, width=42, color=False, - crop=lambda img: img, n_frames=4, dim_order='theano'): + crop=lambda img: img, n_frames=4, dim_order='theano', reward_scale=1,): """A gym wrapper that reshapes, crops and scales image into the desired shapes""" super(PreprocessAtari, self).__init__(env) assert dim_order in ('theano', 'tensorflow') @@ -16,6 +16,7 @@ def __init__(self, env, height=42, width=42, color=False, self.color = color self.dim_order = dim_order + self.reward_scale = reward_scale n_channels = (3 * n_frames) if color else n_frames obs_shape = \ [n_channels, height, width] \ @@ -34,8 +35,8 @@ def step(self, action): """plays breakout for 1 step, returns frame buffer""" new_img, r, done, info = self.env.step(action) self.update_buffer(new_img) - return self.framebuffer, r, done, info + return self.framebuffer, r * self.reward_scale, done, info ### image processing ### def update_buffer(self, img):