Skip to content

Commit

Permalink
Merge pull request #87 from kayendns/airhockey-off-screen-rendering
Browse files Browse the repository at this point in the history
Implemented render_mode 'rgb_array' for AirHockeyEnv
  • Loading branch information
Onur4229 authored Dec 22, 2023
2 parents fa54dbb + aeecf65 commit aa652e3
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from mushroom_rl.core import Environment

class AirHockeyEnv(Environment):
metadata = {"render_modes": ["human"], "render_fps": 50}
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}

def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs):
def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, width=1920, height=1080, **kwargs):
"""
Environment Constructor
Expand Down Expand Up @@ -42,6 +42,17 @@ def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwa
if env_mode == "tournament" and type(interpolation_order) != tuple:
interpolation_order = (interpolation_order, interpolation_order)

self.render_mode = render_mode
self.render_human_active = False

# Determine headless mode based on render_mode
headless = self.render_mode == 'rgb_array'

# Prepare viewer_params
viewer_params = kwargs.get('viewer_params', {})
viewer_params.update({'headless': headless, 'width': width, 'height': height})
kwargs['viewer_params'] = viewer_params

self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs)
self.env_name = env_mode
self.env_info = self.base_env.env_info
Expand Down Expand Up @@ -89,9 +100,6 @@ def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwa
self.env_info['constraints'] = constraint_list
self.env_info['env_name'] = self.env_name

self.render_mode = render_mode
self.render_human_active = False

super().__init__(self.base_env.info)

def step(self, action):
Expand Down Expand Up @@ -119,15 +127,21 @@ def step(self, action):

if self.env_info['env_name'] == "tournament":
obs = np.array(np.split(obs, 2))

if self.render_human_active:
self.base_env.render()

return obs, reward, done, False, info

def render(self):
self.render_human_active = True

if self.render_mode == 'rgb_array':
return self.base_env.render(record = True)
elif self.render_mode == 'human':
self.render_human_active = True
self.base_env.render()
else:
raise ValueError(f"Unsupported render mode: '{self.render_mode}'")

def reset(self, seed=None, options={}):
self.base_env.seed(seed)
obs = self.base_env.reset()
Expand Down Expand Up @@ -185,4 +199,4 @@ def close(self):
J = 0.
gamma = 1.
steps = 0
env.reset()
env.reset()

0 comments on commit aa652e3

Please # to comment.