From d71a13588266256a4c900b5e0d72d10785816c3a Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 19 Dec 2022 12:53:06 +0000 Subject: [PATCH] Replace `np.bool8` with `np.bool_` for numpy 1.24 (#221) --- gymnasium/utils/passive_env_checker.py | 6 +++--- tests/spaces/test_box.py | 2 +- tests/spaces/test_spaces.py | 6 +++++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/gymnasium/utils/passive_env_checker.py b/gymnasium/utils/passive_env_checker.py index fb426b4e5..46d220ae0 100644 --- a/gymnasium/utils/passive_env_checker.py +++ b/gymnasium/utils/passive_env_checker.py @@ -238,7 +238,7 @@ def env_step_passive_checker(env, action): ) obs, reward, done, info = result - if not isinstance(done, (bool, np.bool8)): + if not isinstance(done, (bool, np.bool_)): logger.warn( f"Expects `done` signal to be a boolean, actual type: {type(done)}" ) @@ -246,11 +246,11 @@ def env_step_passive_checker(env, action): obs, reward, terminated, truncated, info = result # np.bool is actual python bool not np boolean type, therefore bool_ or bool8 - if not isinstance(terminated, (bool, np.bool8)): + if not isinstance(terminated, (bool, np.bool_)): logger.warn( f"Expects `terminated` signal to be a boolean, actual type: {type(terminated)}" ) - if not isinstance(truncated, (bool, np.bool8)): + if not isinstance(truncated, (bool, np.bool_)): logger.warn( f"Expects `truncated` signal to be a boolean, actual type: {type(truncated)}" ) diff --git a/tests/spaces/test_box.py b/tests/spaces/test_box.py index 524718e27..24a07d2e8 100644 --- a/tests/spaces/test_box.py +++ b/tests/spaces/test_box.py @@ -51,7 +51,7 @@ def test_shape_inference(box, expected_shape): (np.inf, True), (np.nan, True), # This is a weird case that we allow (True, False), - (np.bool8(True), False), + (np.bool_(True), False), (1 + 1j, False), (np.complex128(1 + 1j), False), ("string", False), diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 910bc9723..932601e19 100644 --- a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -447,7 +447,11 @@ def test_sample_contains(space): assert space.contains(sample) for other_space in TESTING_SPACES: - assert isinstance(space.contains(other_space.sample()), bool) + sample = other_space.sample() + space_contains = other_space.contains(sample) + assert isinstance( + space_contains, bool + ), f"{space_contains}, {type(space_contains)}, {space}, {other_space}, {sample}" @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS)