Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fixed new gym API related to step() and reset() #87

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions nes_py/app/play_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ def play_human(env: gym.Env, callback=None):
# reset if the environment is done
if done:
done = False
state = env.reset()
state, _ = env.reset()
viewer.show(env.unwrapped.screen)
# unwrap the action based on pressed relevant keys
action = keys_to_action.get(viewer.pressed_keys, _NOP)
next_state, reward, done, _ = env.step(action)
next_state, reward, done, truncated, _ = env.step(action)
viewer.show(env.unwrapped.screen)
# pass the observation data through the callback
if callback is not None:
callback(state, action, reward, done, next_state)
callback(state, action, reward, done, truncated, next_state)
state = next_state
# shutdown if the escape key is pressed
if viewer.is_escape_pressed:
Expand Down
4 changes: 2 additions & 2 deletions nes_py/app/play_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def play_random(env, steps):
progress = tqdm(range(steps))
for _ in progress:
if done:
_ = env.reset()
_, _ = env.reset()
action = env.action_space.sample()
_, reward, done, info = env.step(action)
_, reward, done, _, info = env.step(action)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace with

_, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated

progress.set_postfix(reward=reward, info=info)
env.render()
except KeyboardInterrupt:
Expand Down
57 changes: 38 additions & 19 deletions nes_py/nes_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,27 @@
import itertools
import os
import sys

import gym
from gym.core import ObsType, RenderFrame
from gym.spaces import Box
from gym.spaces import Discrete
import numpy as np
from ._rom import ROM
from ._image_viewer import ImageViewer

from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
SupportsFloat,
Tuple,
TypeVar,
Union,
)

# the path to the directory this file is in
_MODULE_PATH = os.path.dirname(__file__)
Expand All @@ -24,7 +38,6 @@
except IndexError:
raise OSError('missing static lib_nes_env*.so library!')


# setup the argument and return types for Width
_LIB.Width.argtypes = None
_LIB.Width.restype = ctypes.c_uint
Expand Down Expand Up @@ -59,7 +72,6 @@
_LIB.Close.argtypes = [ctypes.c_void_p]
_LIB.Close.restype = None


# height in pixels of the NES screen
SCREEN_HEIGHT = _LIB.Height()
# width in pixels of the NES screen
Expand All @@ -71,11 +83,9 @@
# create a type for the screen tensor matrix from C++
SCREEN_TENSOR = ctypes.c_byte * int(np.prod(SCREEN_SHAPE_32_BIT))


# create a type for the RAM vector from C++
RAM_VECTOR = ctypes.c_byte * 0x800


# create a type for the controller buffers from C++
CONTROLLER_VECTOR = ctypes.c_byte * 1

Expand All @@ -94,10 +104,10 @@ class NESEnv(gym.Env):

# observation space for the environment is static across all instances
observation_space = Box(
low=0,
high=255,
shape=SCREEN_SHAPE_24_BIT,
dtype=np.uint8
low=0,
high=255,
shape=SCREEN_SHAPE_24_BIT,
dtype=np.uint8
)

# action space is a bitmap of button press values for the 8 NES buttons
Expand Down Expand Up @@ -145,6 +155,8 @@ def __init__(self, rom_path):
self._has_backup = False
# setup a done flag
self.done = True
# truncated
self.truncated = False
# setup the controllers, screen, and RAM buffers
self.controllers = [self._controller_buffer(port) for port in range(2)]
self.screen = self._screen_buffer()
Expand Down Expand Up @@ -243,7 +255,7 @@ def seed(self, seed=None):
# return the list of seeds used by RNG(s) in the environment
return [seed]

def reset(self, seed=None, options=None, return_info=None):
def reset(self, seed=None, options=None, return_info=None) -> Tuple[ObsType, dict]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove return_info parameter

"""
Reset the state of the environment and returns an initial observation.

Expand All @@ -253,7 +265,9 @@ def reset(self, seed=None, options=None, return_info=None):
return_info (any): unused

Returns:
state (np.ndarray): next frame as a result of the given action
a tuple
state (np.ndarray): next frame as a result of the given action
info dict: Return the info after a step occurs

"""
# Set the seed.
Expand All @@ -270,13 +284,13 @@ def reset(self, seed=None, options=None, return_info=None):
# set the done flag to false
self.done = False
# return the screen from the emulator
return self.screen
return self.screen, self._get_info()

def _did_reset(self):
"""Handle any RAM hacking after a reset occurs."""
pass

def step(self, action):
def step(self, action) -> Tuple[ObsType, float, bool, bool, dict]:
"""
Run one frame of the NES and return the relevant observation data.

Expand Down Expand Up @@ -304,6 +318,7 @@ def step(self, action):
self.done = bool(self._get_done())
# get the info for this step
info = self._get_info()
self.truncated = self._get_truncated()
# call the after step callback
self._did_step(self.done)
# bound the reward in [min, max]
Expand All @@ -312,7 +327,7 @@ def step(self, action):
elif reward > self.reward_range[1]:
reward = self.reward_range[1]
# return the screen from the emulator and other relevant data
return self.screen, reward, self.done, info
return self.screen, reward, self.done, self.truncated, info

def _get_reward(self):
"""Return the reward after a step occurs."""
Expand All @@ -322,6 +337,10 @@ def _get_done(self):
"""Return True if the episode is over, False otherwise."""
return False

def _get_truncated(self):
"""Return True if truncated """
return False

def _get_info(self):
"""Return the info after a step occurs."""
return {}
Expand Down Expand Up @@ -352,7 +371,7 @@ def close(self):
if self.viewer is not None:
self.viewer.close()

def render(self, mode='human'):
def render(self, mode='human') -> Optional[Union[RenderFrame, List[RenderFrame]]]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove mode parameter and add render_mode to __init__ for specifying the type of rendering

"""
Render the environment.

Expand All @@ -378,9 +397,9 @@ def render(self, mode='human'):
caption = self.spec.id
# create the ImageViewer to display frames
self.viewer = ImageViewer(
caption=caption,
height=SCREEN_HEIGHT,
width=SCREEN_WIDTH,
caption=caption,
height=SCREEN_HEIGHT,
width=SCREEN_WIDTH,
)
# show the screen on the image viewer
self.viewer.show(self.screen)
Expand All @@ -401,7 +420,7 @@ def get_keys_to_action(self):
ord('a'), # left
ord('s'), # down
ord('w'), # up
ord('\r'), # start
ord('\r'), # start
ord(' '), # select
ord('p'), # B
ord('o'), # A
Expand All @@ -427,4 +446,4 @@ def get_action_meanings(self):


# explicitly define the outward facing API of this module
__all__ = [NESEnv.__name__]
__all__ = [NESEnv.__name__]
10 changes: 5 additions & 5 deletions nes_py/tests/test_multiple_makes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def play(steps):
done = True
for _ in range(steps):
if done:
_ = env.reset()
_, _ = env.reset()
action = env.action_space.sample()
_, _, done, _ = env.step(action)
_, _, done, _, _ = env.step(action)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace done with terminated and truncated as in the comment before

# close the environment
env.close()

Expand All @@ -45,7 +45,7 @@ class ShouldMakeMultipleEnvironmentsParallel(object):

def test(self):
procs = [None] * self.num_execs
args = (self.steps, )
args = (self.steps,)
# spawn the parallel instances
for idx in range(self.num_execs):
procs[idx] = self.parallel_initializer(target=play, args=args)
Expand Down Expand Up @@ -82,6 +82,6 @@ def test(self):
for _ in range(self.steps):
for idx in range(self.num_envs):
if dones[idx]:
_ = envs[idx].reset()
_, _ = envs[idx].reset()
action = envs[idx].action_space.sample()
_, _, dones[idx], _ = envs[idx].step(action)
_, _, dones[idx], _, _ = envs[idx].step(action)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above

15 changes: 8 additions & 7 deletions nes_py/tests/test_nes_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,21 @@ def test(self):
for _ in range(500):
if done:
# reset the environment and check the output value
state = env.reset()
state, _ = env.reset()
self.assertIsInstance(state, np.ndarray)
# sample a random action and check it
action = env.action_space.sample()
self.assertIsInstance(action, int)
# take a step and check the outputs
output = env.step(action)
self.assertIsInstance(output, tuple)
self.assertEqual(4, len(output))
self.assertEqual(5, len(output))
# check each output
state, reward, done, info = output
state, reward, done, truncated, info = output

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

terminated, truncated, as terminated != done

self.assertIsInstance(state, np.ndarray)
self.assertIsInstance(reward, float)
self.assertIsInstance(done, bool)
self.assertIsInstance(truncated, bool)
self.assertIsInstance(info, dict)
# check the render output
render = env.render('rgb_array')
Expand All @@ -108,9 +109,9 @@ def test(self):

for _ in range(250):
if done:
state = env.reset()
state, _ = env.reset()
done = False
state, _, done, _ = env.step(0)
state, _, done, _, _ = env.step(0)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment on terminated and truncated


backup = state.copy()

Expand All @@ -120,9 +121,9 @@ def test(self):
if done:
state = env.reset()
done = False
state, _, done, _ = env.step(0)
state, _, done, _, _ = env.step(0)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment on terminated and truncated


self.assertFalse(np.array_equal(backup, state))
env._restore()
self.assertTrue(np.array_equal(backup, env.screen))
env.close()
env.close()
4 changes: 2 additions & 2 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
try:
for _ in tqdm.tqdm(range(5000)):
if done:
state = env.reset()
state, _ = env.reset()
done = False
else:
state, reward, done, info = env.step(env.action_space.sample())
state, reward, done, truncated, info = env.step(env.action_space.sample())

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment on terminated and truncated

except KeyboardInterrupt:
pass
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

setup(
name='nes_py',
version='8.2.1',
version='8.2.2',

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor point, but I would make this a minor or major release due to the significant code changes

description='An NES Emulator and OpenAI Gym interface',
long_description=README,
long_description_content_type='text/markdown',
Expand Down