From a4206eeb40474a5bc431cd2e44ccd620e9996d9c Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Thu, 4 Apr 2024 17:56:13 +0200 Subject: [PATCH 01/11] Allow two (same/different) Batch objs to be tested for equality --- test/base/test_batch.py | 42 +++++++++++++++++++++++++++++++++++++++++ tianshou/data/batch.py | 32 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index f11a8d60e..9f0775c26 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -565,6 +565,47 @@ def test_batch_standard_compatibility() -> None: Batch()[0] +def test_batch_eq() -> None: + # Different keys + batch1 = Batch(a=[1, 2], b=[100, 50]) + batch2 = Batch(b=[1, 2], c=[100, 50]) + assert batch1 != batch2, "Keys are not matching." + + # Missing keys + batch1 = Batch(a=[1, 2], b=[2, 3, 4]) + batch2 = Batch(a=[1, 2], b=[2, 3, 4]) + batch2.pop("b") + assert batch1 != batch2, "Keys are not matching." + + # Different types for the same key + batch1 = Batch(a=[1, 2, 3], b=[4, 5]) + batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5])) + assert batch1 != batch2, "Objects have different types" + + # Different array types for the same key + batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5])) + batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5])) + assert batch1 != batch2, "Objects have different types" + + # Nested Batch objects with values + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) + batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5]) + assert batch1 != batch2, "Nested batches have different values." + + # Arrays with different shapes or values + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) + batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5]) + assert batch1 != batch2, "Nested objects have different lengths" + + # Same slice from the same batch + batch1 = Batch(a=[1, 2, 3]) + assert batch1[:2] == batch1[:2], "Batch slice should be the same" + + # Same slice from the same batch with ellipsis and slice + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000]) + assert batch1[..., 1:] == batch1[..., 1:], "Batch slice should be the same" + + if __name__ == "__main__": test_batch() test_batch_over_batch() @@ -576,3 +617,4 @@ def test_batch_standard_compatibility() -> None: test_batch_cat_and_stack() test_batch_copy() test_batch_empty() + test_batch_eq() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 5c7fa036e..24922b5d2 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -268,6 +268,9 @@ def __repr__(self) -> str: def __iter__(self) -> Iterator[Self]: ... + def __eq__(self, other: Self) -> bool: # type: ignore + ... + def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" ... @@ -500,6 +503,35 @@ def __getitem__(self, index: str | IndexType) -> Any: return new_batch raise IndexError("Cannot access item from empty Batch object.") + def __eq__(self, other: Self) -> bool: # type: ignore + this_dict = self.__dict__ + other_dict = other.__dict__ + + if len(this_dict) != len(other_dict): + return False + for batch_key, obs in this_dict.items(): + if batch_key not in other_dict: + return False + + other_val = other.__dict__[batch_key] + + if batch_key in other_dict: + if isinstance(obs, Batch) and isinstance(other_val, Batch): + if not obs == other_val: + return False + elif isinstance(obs, np.ndarray) and isinstance(other_val, np.ndarray): + if not np.all(np.equal(obs.shape, other_val.shape)): + return False + if not np.all(np.equal(obs, other_val)): + return False + elif isinstance(obs, torch.Tensor) and isinstance(other_val, torch.Tensor): + if not torch.equal(obs, other_val): + return False + else: + return False + + return True + def __iter__(self) -> Iterator[Self]: # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea if len(self.__dict__) == 0: From 093956358098a1389c3442830f5660086f21ef1a Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 7 Apr 2024 16:14:29 +0200 Subject: [PATCH 02/11] Refactor Batch equality tests for better readability * Introduce a new test class named `TestBatchEquality`. * Got rid of comments and assert messages as the test method names are expressive enough. --- test/base/test_batch.py | 87 ++++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 40 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 9f0775c26..408f5f0d5 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -565,45 +565,53 @@ def test_batch_standard_compatibility() -> None: Batch()[0] -def test_batch_eq() -> None: - # Different keys - batch1 = Batch(a=[1, 2], b=[100, 50]) - batch2 = Batch(b=[1, 2], c=[100, 50]) - assert batch1 != batch2, "Keys are not matching." - - # Missing keys - batch1 = Batch(a=[1, 2], b=[2, 3, 4]) - batch2 = Batch(a=[1, 2], b=[2, 3, 4]) - batch2.pop("b") - assert batch1 != batch2, "Keys are not matching." - - # Different types for the same key - batch1 = Batch(a=[1, 2, 3], b=[4, 5]) - batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5])) - assert batch1 != batch2, "Objects have different types" - - # Different array types for the same key - batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5])) - batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5])) - assert batch1 != batch2, "Objects have different types" - - # Nested Batch objects with values - batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) - batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5]) - assert batch1 != batch2, "Nested batches have different values." - - # Arrays with different shapes or values - batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) - batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5]) - assert batch1 != batch2, "Nested objects have different lengths" - - # Same slice from the same batch - batch1 = Batch(a=[1, 2, 3]) - assert batch1[:2] == batch1[:2], "Batch slice should be the same" - - # Same slice from the same batch with ellipsis and slice - batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000]) - assert batch1[..., 1:] == batch1[..., 1:], "Batch slice should be the same" +class TestBatchEquality: + @staticmethod + def test_keys_different() -> None: + batch1 = Batch(a=[1, 2], b=[100, 50]) + batch2 = Batch(b=[1, 2], c=[100, 50]) + assert batch1 != batch2 + + @staticmethod + def test_keys_missing() -> None: + batch1 = Batch(a=[1, 2], b=[2, 3, 4]) + batch2 = Batch(a=[1, 2], b=[2, 3, 4]) + batch2.pop("b") + assert batch1 != batch2 + + @staticmethod + def test_types_keys_different() -> None: + batch1 = Batch(a=[1, 2, 3], b=[4, 5]) + batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5])) + assert batch1 != batch2 + + @staticmethod + def test_array_types_different() -> None: + batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5])) + batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5])) + assert batch1 != batch2 + + @staticmethod + def test_nested_values_different() -> None: + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) + batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5]) + assert batch1 != batch2 + + @staticmethod + def test_nested_shapes_different() -> None: + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5]) + batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5]) + assert batch1 != batch2 + + @staticmethod + def test_slice_equal() -> None: + batch1 = Batch(a=[1, 2, 3]) + assert batch1[:2] == batch1[:2] + + @staticmethod + def test_slice_ellipsis_equal() -> None: + batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000]) + assert batch1[..., 1:] == batch1[..., 1:] if __name__ == "__main__": @@ -617,4 +625,3 @@ def test_batch_eq() -> None: test_batch_cat_and_stack() test_batch_copy() test_batch_empty() - test_batch_eq() From 6bcd027b87627600668dcbe5328fcfb2a49324eb Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 7 Apr 2024 16:27:25 +0200 Subject: [PATCH 03/11] Add more simple tests to TestBatchEquality test suite --- test/base/test_batch.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 408f5f0d5..777ef9636 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -613,6 +613,22 @@ def test_slice_ellipsis_equal() -> None: batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000]) assert batch1[..., 1:] == batch1[..., 1:] + @staticmethod + def test_empty_batches() -> None: + assert Batch() == Batch() + + @staticmethod + def test_different_order_keys() -> None: + assert Batch(a=1, b=2) == Batch(b=2, a=1) + + @staticmethod + def test_tuple_and_list_types() -> None: + assert Batch(a=(1, 2)) == Batch(a=[1, 2]) + + @staticmethod + def test_subbatch_dict_and_batch_types() -> None: + assert Batch(a={"x": 1}) == Batch(a=Batch(x=1)) + if __name__ == "__main__": test_batch() From b99973c5d9bb15ab92430883859c43efb8fb639c Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 7 Apr 2024 16:29:38 +0200 Subject: [PATCH 04/11] Allow user to use any object for Batch equality * Restrict to Batch in method --- tianshou/data/batch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 24922b5d2..e66fbcc6e 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -268,7 +268,7 @@ def __repr__(self) -> str: def __iter__(self) -> Iterator[Self]: ... - def __eq__(self, other: Self) -> bool: # type: ignore + def __eq__(self, other: Any) -> bool: ... def to_numpy(self) -> None: @@ -503,7 +503,10 @@ def __getitem__(self, index: str | IndexType) -> Any: return new_batch raise IndexError("Cannot access item from empty Batch object.") - def __eq__(self, other: Self) -> bool: # type: ignore + def __eq__(self, other: Any) -> bool: + if not isinstance(other, self.__class__): + return False + this_dict = self.__dict__ other_dict = other.__dict__ From 68050c026ec6b8152e23eaeb1ef8230db32d524a Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:46:55 +0200 Subject: [PATCH 05/11] Extend Batch.to_dict(recurse=False): * Add tests for `Batch.to_dict()` * Add deepdiff library (for testing equality of dicts in the tests mentioned above) * Update poetry.toml and poetry.lock --- poetry.lock | 35 +++++++++++++++++-- pyproject.toml | 1 + test/base/test_batch.py | 75 ++++++++++++++++++++++++++++++++++++++++- tianshou/data/batch.py | 8 ++--- 4 files changed, 112 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index a6fbf3229..56230ab4f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -903,6 +903,24 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "deepdiff" +version = "7.0.1" +description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." +optional = false +python-versions = ">=3.8" +files = [ + {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, + {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, +] + +[package.dependencies] +ordered-set = ">=4.1.0,<4.2.0" + +[package.extras] +cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"] +optimize = ["orjson"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -3239,7 +3257,6 @@ optional = false python-versions = ">=3" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, ] @@ -3349,6 +3366,20 @@ numpy = ["numpy"] test = ["pytest", "pytest-cov", "pytest-xdist"] torch = ["torch"] +[[package]] +name = "ordered-set" +version = "4.1.0" +description = "An OrderedSet is a custom MutableSet that remembers its order, so that every" +optional = false +python-versions = ">=3.7" +files = [ + {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, + {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, +] + +[package.extras] +dev = ["black", "mypy", "pytest"] + [[package]] name = "overrides" version = "7.4.0" @@ -6223,4 +6254,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b" +content-hash = "a7aa80de549e7af1147d14f9bdd48659b7018732af34022cc734565af1f742e9" diff --git a/pyproject.toml b/pyproject.toml index 813cbd335..d9ea48b81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"] [tool.poetry.dependencies] python = "^3.11" +deepdiff = "^7.0.1" gymnasium = "^0.28.0" h5py = "^3.9.0" numba = "^0.57.1" diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 777ef9636..b694abcf5 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -2,12 +2,13 @@ import pickle import sys from itertools import starmap -from typing import cast +from typing import Any, cast import networkx as nx import numpy as np import pytest import torch +from deepdiff import DeepDiff from tianshou.data import Batch, to_numpy, to_torch @@ -630,6 +631,78 @@ def test_subbatch_dict_and_batch_types() -> None: assert Batch(a={"x": 1}) == Batch(a=Batch(x=1)) +class TestBatchToDict: + @staticmethod + def test_to_dict_empty_batch_no_recurse() -> None: + batch = Batch() + expected: dict[Any, Any] = {} + assert batch.to_dict() == expected + + @staticmethod + def test_to_dict_with_simple_values_recurse() -> None: + batch = Batch(a=1, b="two", c=np.array([3, 4])) + expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])} + assert not DeepDiff(batch.to_dict(recurse=True), expected) + + @staticmethod + def test_to_dict_simple() -> None: + batch = Batch(a=1, b="two") + expected = {"a": np.asanyarray(1), "b": "two"} + assert batch.to_dict() == expected + + @staticmethod + def test_to_dict_nested_batch_no_recurse() -> None: + nested_batch = Batch(c=3) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": nested_batch} + assert not DeepDiff(batch.to_dict(), expected) + + @staticmethod + def test_to_dict_nested_batch_recurse() -> None: + nested_batch = Batch(c=3) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}} + assert not DeepDiff(batch.to_dict(recurse=True), expected) + + @staticmethod + def test_to_dict_multiple_nested_batch_recurse() -> None: + nested_batch = Batch(c=Batch(e=3), d=[100, 200, 300]) + batch = Batch(a=1, b=nested_batch) + expected = { + "a": np.asanyarray(1), + "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])}, + } + assert not DeepDiff(batch.to_dict(recurse=True), expected) + + @staticmethod + def test_to_dict_array() -> None: + batch = Batch(a=np.array([1, 2, 3])) + expected = {"a": np.array([1, 2, 3])} + assert not DeepDiff(batch.to_dict(), expected) + + @staticmethod + def test_to_dict_nested_batch_with_array() -> None: + nested_batch = Batch(c=np.array([4, 5])) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}} + assert not DeepDiff(batch.to_dict(recurse=True), expected) + + @staticmethod + def test_to_dict_torch_tensor() -> None: + t1 = torch.tensor([1.0, 2.0]).detach().cpu().numpy() + batch = Batch(a=t1) + t2 = torch.tensor([1.0, 2.0]).detach().cpu().numpy() + expected = {"a": t2} + assert not DeepDiff(batch.to_dict(), expected) + + @staticmethod + def test_to_dict_nested_batch_with_torch_tensor() -> None: + nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy()) + batch = Batch(a=1, b=nested_batch) + expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}} + assert not DeepDiff(batch.to_dict(recurse=True), expected) + + if __name__ == "__main__": test_batch() test_batch_over_batch() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index e66fbcc6e..fe76df37f 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -399,7 +399,7 @@ def split( """ ... - def to_dict(self) -> dict[str, Any]: + def to_dict(self, recurse: bool = False) -> dict[str, Any]: ... def to_list_of_dicts(self) -> list[dict[str, Any]]: @@ -436,11 +436,11 @@ def __init__( # Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore - def to_dict(self) -> dict[str, Any]: + def to_dict(self, recurse: bool = False) -> dict[str, Any]: result = {} for k, v in self.__dict__.items(): - if isinstance(v, Batch): - v = v.to_dict() + if recurse and isinstance(v, Batch): + v = v.to_dict(recurse=recurse) result[k] = v return result From 98d611ccee59f17c8148e037895b1dffdf6d3b8e Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:51:46 +0200 Subject: [PATCH 06/11] Use DeepDiff to test for Batch equality * Note: `Batch.to_numpy()` should be extended to support also a non in-place operation. --- tianshou/data/batch.py | 32 ++++++-------------------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index fe76df37f..8e14dafd2 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -17,6 +17,7 @@ import numpy as np import torch +from deepdiff import DeepDiff _SingleIndexType = slice | int | EllipsisType IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] @@ -507,33 +508,12 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - this_dict = self.__dict__ - other_dict = other.__dict__ + self.to_numpy() + other.to_numpy() + this_dict = self.to_dict(recurse=True) + other_dict = other.to_dict(recurse=True) - if len(this_dict) != len(other_dict): - return False - for batch_key, obs in this_dict.items(): - if batch_key not in other_dict: - return False - - other_val = other.__dict__[batch_key] - - if batch_key in other_dict: - if isinstance(obs, Batch) and isinstance(other_val, Batch): - if not obs == other_val: - return False - elif isinstance(obs, np.ndarray) and isinstance(other_val, np.ndarray): - if not np.all(np.equal(obs.shape, other_val.shape)): - return False - if not np.all(np.equal(obs, other_val)): - return False - elif isinstance(obs, torch.Tensor) and isinstance(other_val, torch.Tensor): - if not torch.equal(obs, other_val): - return False - else: - return False - - return True + return not DeepDiff(this_dict, other_dict) def __iter__(self) -> Iterator[Self]: # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea From 164cf84122913044aa4f90d1ec97e5f1e29f2db4 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:55:16 +0200 Subject: [PATCH 07/11] Implement non-inplace to_numpy for Batch * Breaking change: Previous in-place `Batch.to_numpy` is now `Batch.to_numpy_` (following naming convention of other in-place methods). * Update places where in-place was expected * Add tests for both to_numpy/to_numpy_ --- docs/01_tutorials/03_batch.rst | 4 ++-- docs/02_notebooks/L1_Batch.ipynb | 2 +- test/base/test_batch.py | 26 +++++++++++++++++++++++++- tianshou/data/batch.py | 31 ++++++++++++++++++++++++------- tianshou/data/utils/converter.py | 2 +- 5 files changed, 53 insertions(+), 12 deletions(-) diff --git a/docs/01_tutorials/03_batch.rst b/docs/01_tutorials/03_batch.rst index 71f82f84e..46fa86b3d 100644 --- a/docs/01_tutorials/03_batch.rst +++ b/docs/01_tutorials/03_batch.rst @@ -485,8 +485,8 @@ Miscellaneous Notes tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]) - >>> # data.to_numpy is also available - >>> data.to_numpy() + >>> # data.to_numpy_ is also available + >>> data.to_numpy_() .. raw:: html diff --git a/docs/02_notebooks/L1_Batch.ipynb b/docs/02_notebooks/L1_Batch.ipynb index 9c80349da..54008ee64 100644 --- a/docs/02_notebooks/L1_Batch.ipynb +++ b/docs/02_notebooks/L1_Batch.ipynb @@ -331,7 +331,7 @@ }, "outputs": [], "source": [ - "batch_cat.to_numpy()\n", + "batch_cat.to_numpy_()\n", "print(batch_cat)\n", "batch_cat.to_torch()\n", "print(batch_cat)" diff --git a/test/base/test_batch.py b/test/base/test_batch.py index b694abcf5..0ce219e75 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -478,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None: a_mem_addr_orig = batch.a.__array_interface__["data"][0] c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] batch.to_torch() - batch.to_numpy() + batch.to_numpy_() a_mem_addr_new = batch.a.__array_interface__["data"][0] c_mem_addr_new = batch.b.c.__array_interface__["data"][0] assert a_mem_addr_new == a_mem_addr_orig @@ -703,6 +703,30 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None: assert not DeepDiff(batch.to_dict(recurse=True), expected) +class TestToNumpy: + """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` .""" + + @staticmethod + def test_to_numpy() -> None: + batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) + new_batch: Batch = Batch.to_numpy(batch) + assert id(batch) != id(new_batch) + assert isinstance(batch.b, torch.Tensor) + assert isinstance(batch.c.d, torch.Tensor) + + assert isinstance(new_batch.b, np.ndarray) + assert isinstance(new_batch.c.d, np.ndarray) + + @staticmethod + def test_to_numpy_() -> None: + batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) + id_batch = id(batch) + batch.to_numpy_() + assert id_batch == id(batch) + assert isinstance(batch.b, np.ndarray) + assert isinstance(batch.c.d, np.ndarray) + + if __name__ == "__main__": test_batch() test_batch_over_batch() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 8e14dafd2..4eee0cd81 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -272,7 +272,12 @@ def __iter__(self) -> Iterator[Self]: def __eq__(self, other: Any) -> bool: ... - def to_numpy(self) -> None: + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + """Change all torch.Tensor to numpy.ndarray and return a new Batch.""" + ... + + def to_numpy_(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" ... @@ -508,10 +513,10 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - self.to_numpy() - other.to_numpy() - this_dict = self.to_dict(recurse=True) - other_dict = other.to_dict(recurse=True) + this_batch_no_torch_tensor: Batch = Batch.to_numpy(self) + other_batch_no_torch_tensor: Batch = Batch.to_numpy(other) + this_dict = this_batch_no_torch_tensor.to_dict(recurse=True) + other_dict = other_batch_no_torch_tensor.to_dict(recurse=True) return not DeepDiff(this_dict, other_dict) @@ -614,12 +619,24 @@ def __repr__(self) -> str: self_str = self.__class__.__name__ + "()" return self_str - def to_numpy(self) -> None: + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + batch_dict = deepcopy(batch) + for batch_key, obj in batch_dict.items(): + if isinstance(obj, torch.Tensor): + batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy() + elif isinstance(obj, Batch): + obj = Batch.to_numpy(obj) + batch_dict.__dict__[batch_key] = obj + + return batch_dict + + def to_numpy_(self) -> None: for batch_key, obj in self.items(): if isinstance(obj, torch.Tensor): self.__dict__[batch_key] = obj.detach().cpu().numpy() elif isinstance(obj, Batch): - obj.to_numpy() + obj.to_numpy_() def to_torch( self, diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 2df462da5..7edf3ff45 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray: return np.array(None, dtype=object) if isinstance(x, dict | Batch): x = Batch(x) if isinstance(x, dict) else deepcopy(x) - x.to_numpy() + x.to_numpy_() return x if isinstance(x, list | tuple): return to_numpy(_parse_value(x)) From def4f0df954097ccaa185a9b9dd06004eb369c4f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 16 Apr 2024 16:48:00 +0200 Subject: [PATCH 08/11] Batch.to_dict should be a rremain by default a recursive conversion --- test/base/test_batch.py | 2 +- tianshou/data/batch.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index b45ea00c5..d1b459b76 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -655,7 +655,7 @@ def test_to_dict_nested_batch_no_recurse() -> None: nested_batch = Batch(c=3) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": nested_batch} - assert not DeepDiff(batch.to_dict(), expected) + assert not DeepDiff(batch.to_dict(recurse=False), expected) @staticmethod def test_to_dict_nested_batch_recurse() -> None: diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 95b25f741..b09847a20 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -405,7 +405,7 @@ def split( """ ... - def to_dict(self, recurse: bool = False) -> dict[str, Any]: + def to_dict(self, recurse: bool = True) -> dict[str, Any]: ... def to_list_of_dicts(self) -> list[dict[str, Any]]: @@ -442,7 +442,7 @@ def __init__( # Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore - def to_dict(self, recurse: bool = False) -> dict[str, Any]: + def to_dict(self, recurse: bool = True) -> dict[str, Any]: result = {} for k, v in self.__dict__.items(): if recurse and isinstance(v, Batch): From cebde459a2cc3565b5f81e9591b13f9de88e9fd2 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:01:43 +0200 Subject: [PATCH 09/11] Change argname from recurse -> recursive --- test/base/test_batch.py | 12 ++++++------ tianshou/data/batch.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index d1b459b76..82ff4a3fb 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -642,7 +642,7 @@ def test_to_dict_empty_batch_no_recurse() -> None: def test_to_dict_with_simple_values_recurse() -> None: batch = Batch(a=1, b="two", c=np.array([3, 4])) expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])} - assert not DeepDiff(batch.to_dict(recurse=True), expected) + assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_simple() -> None: @@ -655,14 +655,14 @@ def test_to_dict_nested_batch_no_recurse() -> None: nested_batch = Batch(c=3) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": nested_batch} - assert not DeepDiff(batch.to_dict(recurse=False), expected) + assert not DeepDiff(batch.to_dict(recursive=False), expected) @staticmethod def test_to_dict_nested_batch_recurse() -> None: nested_batch = Batch(c=3) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}} - assert not DeepDiff(batch.to_dict(recurse=True), expected) + assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_multiple_nested_batch_recurse() -> None: @@ -672,7 +672,7 @@ def test_to_dict_multiple_nested_batch_recurse() -> None: "a": np.asanyarray(1), "b": {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])}, } - assert not DeepDiff(batch.to_dict(recurse=True), expected) + assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_array() -> None: @@ -685,7 +685,7 @@ def test_to_dict_nested_batch_with_array() -> None: nested_batch = Batch(c=np.array([4, 5])) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}} - assert not DeepDiff(batch.to_dict(recurse=True), expected) + assert not DeepDiff(batch.to_dict(recursive=True), expected) @staticmethod def test_to_dict_torch_tensor() -> None: @@ -700,7 +700,7 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None: nested_batch = Batch(c=torch.tensor([4, 5]).detach().cpu().numpy()) batch = Batch(a=1, b=nested_batch) expected = {"a": np.asanyarray(1), "b": {"c": torch.tensor([4, 5]).detach().cpu().numpy()}} - assert not DeepDiff(batch.to_dict(recurse=True), expected) + assert not DeepDiff(batch.to_dict(recursive=True), expected) class TestToNumpy: diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index b09847a20..d911788c6 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -442,11 +442,11 @@ def __init__( # Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore - def to_dict(self, recurse: bool = True) -> dict[str, Any]: + def to_dict(self, recursive: bool = True) -> dict[str, Any]: result = {} for k, v in self.__dict__.items(): - if recurse and isinstance(v, Batch): - v = v.to_dict(recurse=recurse) + if recursive and isinstance(v, Batch): + v = v.to_dict(recursive=recursive) result[k] = v return result @@ -518,8 +518,8 @@ def __eq__(self, other: Any) -> bool: this_batch_no_torch_tensor: Batch = Batch.to_numpy(self) other_batch_no_torch_tensor: Batch = Batch.to_numpy(other) - this_dict = this_batch_no_torch_tensor.to_dict(recurse=True) - other_dict = other_batch_no_torch_tensor.to_dict(recurse=True) + this_dict = this_batch_no_torch_tensor.to_dict(recursive=True) + other_dict = other_batch_no_torch_tensor.to_dict(recursive=True) return not DeepDiff(this_dict, other_dict) From bee533c4103a4422a20f452d6895334c132b3939 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:08:55 +0200 Subject: [PATCH 10/11] Allow two (same/different) Batch objs to be tested for equality #1098 --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 126f81a89..786ec1691 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 - `SamplingConfig` supports `batch_size=None`. #1077 +- Batch received new method: `to_numpy_`. #1098 +- `to_dict` in Batch supports also non-recursive conversion. #1098 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 @@ -34,6 +36,7 @@ expicitly or pass `reset_before_collect=True` . #1063 - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 +- The method `to_numpy` in `data.utils.batch.Batch` is not in-place anymore. Instead, a new method `to_numpy_` does the conversion in-place. #1098 ### Tests - Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081 From a04df6bbb559455172e6f4d0c85f0a83a42bec8a Mon Sep 17 00:00:00 2001 From: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:11:30 +0200 Subject: [PATCH 11/11] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 786ec1691..24c72ed05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - `SamplingConfig` supports `batch_size=None`. #1077 - Batch received new method: `to_numpy_`. #1098 - `to_dict` in Batch supports also non-recursive conversion. #1098 +- Batch __eq__ now implemented, semantic equality check of batches is now possible. #1098 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063