Jax Cliffwalking Env #407

Merged

Contributor

@balisujohn balisujohn commented Mar 26, 2023

Description

This adds a CliffWalkingJaxEnv to the tabular environment suite, as well as the underlying FuncEnv subclass CliffWalkingFunctional.
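To illustrate what a functional (stateless) formulation of cliff walking might look like, here is a minimal pure-Python sketch. It is not the PR's CliffWalkingFunctional code and does not use JAX. The 4x12 grid, the action order (up, right, down, left), the -1/-100 rewards, and every name in the block (State, transition, reward, terminal) are assumptions based on the classic textbook version of the task.

```python
from typing import NamedTuple

# Assumed layout: the classic 4x12 grid with the cliff on the bottom row.
ROWS, COLS = 4, 12
# Assumed action order: 0 = up, 1 = right, 2 = down, 3 = left.
MOVES = ((-1, 0), (0, 1), (1, 0), (0, -1))


class State(NamedTuple):
    """Hypothetical state: just the agent's grid position."""

    row: int
    col: int


START = State(ROWS - 1, 0)
GOAL = State(ROWS - 1, COLS - 1)


def is_cliff(state: State) -> bool:
    """Bottom-row cells strictly between start and goal are cliff cells."""
    return state.row == ROWS - 1 and 0 < state.col < COLS - 1


def _move(state: State, action: int) -> State:
    """Apply the action and clamp the result to the grid bounds."""
    d_row, d_col = MOVES[action]
    row = min(max(state.row + d_row, 0), ROWS - 1)
    col = min(max(state.col + d_col, 0), COLS - 1)
    return State(row, col)


def transition(state: State, action: int) -> State:
    """Pure function: returns the next state without mutating anything."""
    moved = _move(state, action)
    # Stepping onto the cliff sends the agent back to the start.
    return START if is_cliff(moved) else moved


def reward(state: State, action: int) -> int:
    """-100 for stepping onto the cliff, -1 for any other step (assumed)."""
    return -100 if is_cliff(_move(state, action)) else -1


def terminal(state: State) -> bool:
    """The episode ends once the agent reaches the goal cell."""
    return state == GOAL
```

Because every function takes the state as an argument and returns a new value, nothing is stored on an environment object. That property is what makes this style a natural fit for JAX transformations, although this sketch itself is plain Python.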

CliffWalkingJaxEnv is registered with the name "Jax-CliffWalking-v0". This PR still needs a bit of polish, but you can try it out for yourself in GUI mode with:

python3 ./gymnasium/envs/tabular/cliffwalking.py


Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have run the pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@balisujohn balisujohn marked this pull request as draft March 26, 2023 06:35
gymnasium/envs/__init__.py (outdated, resolved)
gymnasium/envs/tabular/cliffwalking.py (outdated, resolved)
gymnasium/envs/tabular/cliffwalking.py (outdated, resolved)
gymnasium/envs/tabular/cliffwalking.py (outdated, resolved)
@balisujohn balisujohn marked this pull request as ready for review April 1, 2023 20:33
@@ -17,7 +17,7 @@
 from gymnasium.wrappers import HumanRendering


-RenderStateType = Tuple["pygame.Surface"] # type: ignore # noqa: F821
+RenderStateType = Tuple["pygame.Surface", Tuple[int, int], int, Tuple[int, int], "numpy.ndarray", Tuple["pygame.Surface", "pygame.Surface", "pygame.Surface", "pygame.Surface"], "pygame.Surface", "pygame.Surface", Tuple[str, str], Tuple["pygame.surface", "pygame.surface"], Tuple[str, str], Tuple["pygame.surface", "pygame.surface"], "pygame.surface"] # type: ignore # noqa: F821
Member

Would we want the RenderStateType to be a NamedTuple?

Contributor Author

changed
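As a rough illustration of why a NamedTuple can be easier to work with than a long positional Tuple alias, the sketch below compares the two on a hypothetical three-field subset of the render state. The field names shape and nS come from the PR; the placeholder values, the `object` stand-in for a pygame surface, and the names PositionalRenderState and NamedRenderState are assumptions made for this example.

```python
from typing import NamedTuple, Tuple

# Positional alias: readers must remember that index 1 is the grid shape.
PositionalRenderState = Tuple[object, Tuple[int, int], int]


class NamedRenderState(NamedTuple):
    """Hypothetical three-field subset of the render state."""

    screen: object  # stands in for a pygame.Surface
    shape: Tuple[int, int]
    nS: int


positional: PositionalRenderState = (None, (4, 12), 48)
named = NamedRenderState(screen=None, shape=(4, 12), nS=48)

# Same data, but the named version documents itself at the call site.
rows_from_index = positional[1][0]
rows_from_name = named.shape[0]

# A NamedTuple is still a tuple, so unpacking and indexing keep working.
screen, shape, n_states = named
```

Since the named version remains a tuple, existing code that unpacks or indexes the render state positionally should keep working after such a change.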

Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

This looks good, and I think close to being merged. Thanks for your hard work.

I was able to solve the type issue using from __future__ import annotations together with an if TYPE_CHECKING: import pygame guard:

from __future__ import annotations

from os import path
from typing import TYPE_CHECKING, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey

from gymnasium import spaces
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
from gymnasium.utils import EzPickle
from gymnasium.wrappers import HumanRendering


if TYPE_CHECKING:
    import pygame


class RenderStateType(NamedTuple):
    """A named tuple which contains the full render state of the Cliffwalking Env. This is static during the episode."""

    screen: pygame.Surface
    shape: tuple[int, int]
    nS: int
    cell_size: tuple[int, int]
    cliff: np.ndarray
    elf_images: tuple[pygame.Surface, pygame.Surface, pygame.Surface, pygame.Surface]
    start_img: pygame.Surface
    goal_img: pygame.Surface
    bg_imgs: tuple[str, str]
    mountain_bg_img: tuple[pygame.Surface, pygame.Surface]
    near_cliff_imgs: tuple[str, str]
    near_cliff_img: tuple[pygame.Surface, pygame.Surface]
    cliff_img: pygame.Surface
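The pattern in the snippet above can be tried in isolation. The following is a minimal sketch, assuming pygame may not be installed: the postponed annotations are kept as strings and the guarded import is skipped at runtime, so the module loads either way. MiniRenderState and describe are hypothetical names for this example, not part of the PR.

```python
from __future__ import annotations  # must be the first statement in the module

from typing import TYPE_CHECKING, NamedTuple

if TYPE_CHECKING:
    # Only static type checkers follow this branch. At runtime it is skipped,
    # so pygame does not need to be installed to import this module.
    import pygame


class MiniRenderState(NamedTuple):
    """Hypothetical two-field stand-in for the PR's RenderStateType."""

    screen: pygame.Surface  # never evaluated at runtime
    shape: tuple[int, int]


def describe(state: MiniRenderState) -> str:
    """Return a short description of the grid dimensions."""
    rows, cols = state.shape
    return f"{rows}x{cols} grid"


# A placeholder object stands in for a real surface in this sketch.
state = MiniRenderState(screen=object(), shape=(4, 12))
print(describe(state))  # prints "4x12 grid"
```

The trade-off is that the annotations are only meaningful to a type checker. Code that resolves them at runtime (for example with typing.get_type_hints) would need pygame to be importable.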

@balisujohn
Contributor Author

added that typing fix

@pseudo-rnd-thoughts pseudo-rnd-thoughts merged commit 51a735e into Farama-Foundation:main Apr 9, 2023