-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fixed new gym API related to step() and reset() #87
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,13 +4,27 @@ | |
import itertools | ||
import os | ||
import sys | ||
|
||
import gym | ||
from gym.core import ObsType, RenderFrame | ||
from gym.spaces import Box | ||
from gym.spaces import Discrete | ||
import numpy as np | ||
from ._rom import ROM | ||
from ._image_viewer import ImageViewer | ||
|
||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Dict, | ||
Generic, | ||
List, | ||
Optional, | ||
SupportsFloat, | ||
Tuple, | ||
TypeVar, | ||
Union, | ||
) | ||
|
||
# the path to the directory this file is in | ||
_MODULE_PATH = os.path.dirname(__file__) | ||
|
@@ -24,7 +38,6 @@ | |
except IndexError: | ||
raise OSError('missing static lib_nes_env*.so library!') | ||
|
||
|
||
# setup the argument and return types for Width | ||
_LIB.Width.argtypes = None | ||
_LIB.Width.restype = ctypes.c_uint | ||
|
@@ -59,7 +72,6 @@ | |
_LIB.Close.argtypes = [ctypes.c_void_p] | ||
_LIB.Close.restype = None | ||
|
||
|
||
# height in pixels of the NES screen | ||
SCREEN_HEIGHT = _LIB.Height() | ||
# width in pixels of the NES screen | ||
|
@@ -71,11 +83,9 @@ | |
# create a type for the screen tensor matrix from C++ | ||
SCREEN_TENSOR = ctypes.c_byte * int(np.prod(SCREEN_SHAPE_32_BIT)) | ||
|
||
|
||
# create a type for the RAM vector from C++ | ||
RAM_VECTOR = ctypes.c_byte * 0x800 | ||
|
||
|
||
# create a type for the controller buffers from C++ | ||
CONTROLLER_VECTOR = ctypes.c_byte * 1 | ||
|
||
|
@@ -94,10 +104,10 @@ class NESEnv(gym.Env): | |
|
||
# observation space for the environment is static across all instances | ||
observation_space = Box( | ||
low=0, | ||
high=255, | ||
shape=SCREEN_SHAPE_24_BIT, | ||
dtype=np.uint8 | ||
low=0, | ||
high=255, | ||
shape=SCREEN_SHAPE_24_BIT, | ||
dtype=np.uint8 | ||
) | ||
|
||
# action space is a bitmap of button press values for the 8 NES buttons | ||
|
@@ -145,6 +155,8 @@ def __init__(self, rom_path): | |
self._has_backup = False | ||
# setup a done flag | ||
self.done = True | ||
# truncated | ||
self.truncated = False | ||
# setup the controllers, screen, and RAM buffers | ||
self.controllers = [self._controller_buffer(port) for port in range(2)] | ||
self.screen = self._screen_buffer() | ||
|
@@ -243,7 +255,7 @@ def seed(self, seed=None): | |
# return the list of seeds used by RNG(s) in the environment | ||
return [seed] | ||
|
||
def reset(self, seed=None, options=None, return_info=None): | ||
def reset(self, seed=None, options=None, return_info=None) -> Tuple[ObsType, dict]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove |
||
""" | ||
Reset the state of the environment and returns an initial observation. | ||
|
||
|
@@ -253,7 +265,9 @@ def reset(self, seed=None, options=None, return_info=None): | |
return_info (any): unused | ||
|
||
Returns: | ||
state (np.ndarray): next frame as a result of the given action | ||
a tuple | ||
state (np.ndarray): next frame as a result of the given action | ||
info dict: Return the info after a step occurs | ||
|
||
""" | ||
# Set the seed. | ||
|
@@ -270,13 +284,13 @@ def reset(self, seed=None, options=None, return_info=None): | |
# set the done flag to false | ||
self.done = False | ||
# return the screen from the emulator | ||
return self.screen | ||
return self.screen, self._get_info() | ||
|
||
def _did_reset(self): | ||
"""Handle any RAM hacking after a reset occurs.""" | ||
pass | ||
|
||
def step(self, action): | ||
def step(self, action) -> Tuple[ObsType, float, bool, bool, dict]: | ||
""" | ||
Run one frame of the NES and return the relevant observation data. | ||
|
||
|
@@ -304,6 +318,7 @@ def step(self, action): | |
self.done = bool(self._get_done()) | ||
# get the info for this step | ||
info = self._get_info() | ||
self.truncated = self._get_truncated() | ||
# call the after step callback | ||
self._did_step(self.done) | ||
# bound the reward in [min, max] | ||
|
@@ -312,7 +327,7 @@ def step(self, action): | |
elif reward > self.reward_range[1]: | ||
reward = self.reward_range[1] | ||
# return the screen from the emulator and other relevant data | ||
return self.screen, reward, self.done, info | ||
return self.screen, reward, self.done, self.truncated, info | ||
|
||
def _get_reward(self): | ||
"""Return the reward after a step occurs.""" | ||
|
@@ -322,6 +337,10 @@ def _get_done(self): | |
"""Return True if the episode is over, False otherwise.""" | ||
return False | ||
|
||
def _get_truncated(self): | ||
"""Return True if truncated """ | ||
return False | ||
|
||
def _get_info(self): | ||
"""Return the info after a step occurs.""" | ||
return {} | ||
|
@@ -352,7 +371,7 @@ def close(self): | |
if self.viewer is not None: | ||
self.viewer.close() | ||
|
||
def render(self, mode='human'): | ||
def render(self, mode='human') -> Optional[Union[RenderFrame, List[RenderFrame]]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove mode parameter and add |
||
""" | ||
Render the environment. | ||
|
||
|
@@ -378,9 +397,9 @@ def render(self, mode='human'): | |
caption = self.spec.id | ||
# create the ImageViewer to display frames | ||
self.viewer = ImageViewer( | ||
caption=caption, | ||
height=SCREEN_HEIGHT, | ||
width=SCREEN_WIDTH, | ||
caption=caption, | ||
height=SCREEN_HEIGHT, | ||
width=SCREEN_WIDTH, | ||
) | ||
# show the screen on the image viewer | ||
self.viewer.show(self.screen) | ||
|
@@ -401,7 +420,7 @@ def get_keys_to_action(self): | |
ord('a'), # left | ||
ord('s'), # down | ||
ord('w'), # up | ||
ord('\r'), # start | ||
ord('\r'), # start | ||
ord(' '), # select | ||
ord('p'), # B | ||
ord('o'), # A | ||
|
@@ -427,4 +446,4 @@ def get_action_meanings(self): | |
|
||
|
||
# explicitly define the outward facing API of this module | ||
__all__ = [NESEnv.__name__] | ||
__all__ = [NESEnv.__name__] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,9 +24,9 @@ def play(steps): | |
done = True | ||
for _ in range(steps): | ||
if done: | ||
_ = env.reset() | ||
_, _ = env.reset() | ||
action = env.action_space.sample() | ||
_, _, done, _ = env.step(action) | ||
_, _, done, _, _ = env.step(action) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Replace done with terminated and truncated as in the comment before |
||
# close the environment | ||
env.close() | ||
|
||
|
@@ -45,7 +45,7 @@ class ShouldMakeMultipleEnvironmentsParallel(object): | |
|
||
def test(self): | ||
procs = [None] * self.num_execs | ||
args = (self.steps, ) | ||
args = (self.steps,) | ||
# spawn the parallel instances | ||
for idx in range(self.num_execs): | ||
procs[idx] = self.parallel_initializer(target=play, args=args) | ||
|
@@ -82,6 +82,6 @@ def test(self): | |
for _ in range(self.steps): | ||
for idx in range(self.num_envs): | ||
if dones[idx]: | ||
_ = envs[idx].reset() | ||
_, _ = envs[idx].reset() | ||
action = envs[idx].action_space.sample() | ||
_, _, dones[idx], _ = envs[idx].step(action) | ||
_, _, dones[idx], _, _ = envs[idx].step(action) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as above |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,20 +79,21 @@ def test(self): | |
for _ in range(500): | ||
if done: | ||
# reset the environment and check the output value | ||
state = env.reset() | ||
state, _ = env.reset() | ||
self.assertIsInstance(state, np.ndarray) | ||
# sample a random action and check it | ||
action = env.action_space.sample() | ||
self.assertIsInstance(action, int) | ||
# take a step and check the outputs | ||
output = env.step(action) | ||
self.assertIsInstance(output, tuple) | ||
self.assertEqual(4, len(output)) | ||
self.assertEqual(5, len(output)) | ||
# check each output | ||
state, reward, done, info = output | ||
state, reward, done, truncated, info = output | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
self.assertIsInstance(state, np.ndarray) | ||
self.assertIsInstance(reward, float) | ||
self.assertIsInstance(done, bool) | ||
self.assertIsInstance(truncated, bool) | ||
self.assertIsInstance(info, dict) | ||
# check the render output | ||
render = env.render('rgb_array') | ||
|
@@ -108,9 +109,9 @@ def test(self): | |
|
||
for _ in range(250): | ||
if done: | ||
state = env.reset() | ||
state, _ = env.reset() | ||
done = False | ||
state, _, done, _ = env.step(0) | ||
state, _, done, _, _ = env.step(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment on terminated and truncated |
||
|
||
backup = state.copy() | ||
|
||
|
@@ -120,9 +121,9 @@ def test(self): | |
if done: | ||
state = env.reset() | ||
done = False | ||
state, _, done, _ = env.step(0) | ||
state, _, done, _, _ = env.step(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment on terminated and truncated |
||
|
||
self.assertFalse(np.array_equal(backup, state)) | ||
env._restore() | ||
self.assertTrue(np.array_equal(backup, env.screen)) | ||
env.close() | ||
env.close() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,9 +7,9 @@ | |
try: | ||
for _ in tqdm.tqdm(range(5000)): | ||
if done: | ||
state = env.reset() | ||
state, _ = env.reset() | ||
done = False | ||
else: | ||
state, reward, done, info = env.step(env.action_space.sample()) | ||
state, reward, done, truncated, info = env.step(env.action_space.sample()) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment on terminated and truncated |
||
except KeyboardInterrupt: | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,7 @@ | |
|
||
setup( | ||
name='nes_py', | ||
version='8.2.1', | ||
version='8.2.2', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor point, but I would make this a minor or major release due to the significant code changes |
||
description='An NES Emulator and OpenAI Gym interface', | ||
long_description=README, | ||
long_description_content_type='text/markdown', | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace with