Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow two (same/different) Batch objs to be tested for equality #1098

Merged
merged 12 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/01_tutorials/03_batch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,8 @@ Miscellaneous Notes
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> # data.to_numpy is also available
>>> data.to_numpy()
>>> # data.to_numpy_ is also available
>>> data.to_numpy_()

.. raw:: html

Expand Down
2 changes: 1 addition & 1 deletion docs/02_notebooks/L1_Batch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@
},
"outputs": [],
"source": [
"batch_cat.to_numpy()\n",
"batch_cat.to_numpy_()\n",
"print(batch_cat)\n",
"batch_cat.to_torch()\n",
"print(batch_cat)"
Expand Down
35 changes: 33 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ exclude = ["test/*", "examples/*", "docs/*"]

[tool.poetry.dependencies]
python = "^3.11"
deepdiff = "^7.0.1"
gymnasium = "^0.28.0"
h5py = "^3.9.0"
numba = "^0.57.1"
Expand Down
166 changes: 164 additions & 2 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import pickle
import sys
from itertools import starmap
from typing import cast
from typing import Any, cast

import networkx as nx
import numpy as np
import pytest
import torch
from deepdiff import DeepDiff

from tianshou.data import Batch, to_numpy, to_torch

Expand Down Expand Up @@ -477,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
batch.to_torch()
batch.to_numpy()
batch.to_numpy_()
a_mem_addr_new = batch.a.__array_interface__["data"][0]
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
assert a_mem_addr_new == a_mem_addr_orig
Expand Down Expand Up @@ -565,6 +566,167 @@ def test_batch_standard_compatibility() -> None:
Batch()[0]


class TestBatchEquality:
    """Equality / inequality checks between ``Batch`` objects.

    Each case builds its batches from scratch and compares them with ``==`` or ``!=``.
    """

    @staticmethod
    def test_keys_different() -> None:
        """Batches whose key sets only partially overlap must compare unequal."""
        left = Batch(a=[1, 2], b=[100, 50])
        right = Batch(b=[1, 2], c=[100, 50])
        assert left != right

    @staticmethod
    def test_keys_missing() -> None:
        """Removing a key from an otherwise identical batch must break equality."""
        complete = Batch(a=[1, 2], b=[2, 3, 4])
        reduced = Batch(a=[1, 2], b=[2, 3, 4])
        reduced.pop("b")
        assert complete != reduced

    @staticmethod
    def test_types_keys_different() -> None:
        """A plain sequence under a key is not equal to a nested batch under that key."""
        flat = Batch(a=[1, 2, 3], b=[4, 5])
        nested = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
        assert flat != nested

    @staticmethod
    def test_array_types_different() -> None:
        """The same numbers held as ``np.array`` vs. ``torch.Tensor`` must compare unequal."""
        with_numpy = Batch(a=[1, 2, 3], b=np.array([4, 5]))
        with_torch = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
        assert with_numpy != with_torch

    @staticmethod
    def test_nested_values_different() -> None:
        """A single differing element inside a nested batch must break equality."""
        left = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        right = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
        assert left != right

    @staticmethod
    def test_nested_shapes_different() -> None:
        """Nested entries of different length must break equality."""
        left = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
        right = Batch(a=Batch(a=[1, 4]), b=[4, 5])
        assert left != right

    @staticmethod
    def test_slice_equal() -> None:
        """Two independently taken, identical slices of one batch compare equal."""
        source = Batch(a=[1, 2, 3])
        first_two = slice(None, 2)
        assert source[first_two] == source[first_two]

    @staticmethod
    def test_slice_ellipsis_equal() -> None:
        """Same as the plain slice case, but indexing with an ellipsis plus a slice."""
        source = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
        index = (..., slice(1, None))
        assert source[index] == source[index]

    @staticmethod
    def test_empty_batches() -> None:
        """Two empty batches compare equal."""
        assert Batch() == Batch()

    @staticmethod
    def test_different_order_keys() -> None:
        """The order in which keys were supplied does not affect equality."""
        assert Batch(a=1, b=2) == Batch(b=2, a=1)

    @staticmethod
    def test_tuple_and_list_types() -> None:
        """A tuple and a list holding the same values yield equal batches."""
        assert Batch(a=(1, 2)) == Batch(a=[1, 2])

    @staticmethod
    def test_subbatch_dict_and_batch_types() -> None:
        """A dict value and an equivalent nested ``Batch`` value yield equal batches."""
        assert Batch(a={"x": 1}) == Batch(a=Batch(x=1))


class TestBatchToDict:
    """Checks for ``Batch.to_dict`` with and without ``recurse=True``.

    Results that contain arrays are compared through ``DeepDiff``; an empty diff means equal.
    """

    @staticmethod
    def test_to_dict_empty_batch_no_recurse() -> None:
        """An empty batch converts to an empty dict."""
        batch = Batch()
        expected: dict[Any, Any] = {}
        assert batch.to_dict() == expected

    @staticmethod
    def test_to_dict_with_simple_values_recurse() -> None:
        """Scalar, string and array values with ``recurse=True`` and nothing nested."""
        batch = Batch(a=1, b="two", c=np.array([3, 4]))
        expected = {"a": np.asanyarray(1), "b": "two", "c": np.array([3, 4])}
        diff = DeepDiff(batch.to_dict(recurse=True), expected)
        assert not diff

    @staticmethod
    def test_to_dict_simple() -> None:
        """Scalar and string values with the default (non-recursive) call."""
        batch = Batch(a=1, b="two")
        expected = {"a": np.asanyarray(1), "b": "two"}
        assert batch.to_dict() == expected

    @staticmethod
    def test_to_dict_nested_batch_no_recurse() -> None:
        """Without ``recurse`` the nested batch is expected to stay a ``Batch`` in the result."""
        inner = Batch(c=3)
        batch = Batch(a=1, b=inner)
        expected = {"a": np.asanyarray(1), "b": inner}
        diff = DeepDiff(batch.to_dict(), expected)
        assert not diff

    @staticmethod
    def test_to_dict_nested_batch_recurse() -> None:
        """With ``recurse=True`` the nested batch is expected to become a dict."""
        inner = Batch(c=3)
        batch = Batch(a=1, b=inner)
        expected = {"a": np.asanyarray(1), "b": {"c": np.asanyarray(3)}}
        diff = DeepDiff(batch.to_dict(recurse=True), expected)
        assert not diff

    @staticmethod
    def test_to_dict_multiple_nested_batch_recurse() -> None:
        """Two levels of nesting are expected to be converted with ``recurse=True``."""
        inner = Batch(c=Batch(e=3), d=[100, 200, 300])
        batch = Batch(a=1, b=inner)
        expected_inner = {"c": {"e": np.asanyarray(3)}, "d": np.array([100, 200, 300])}
        expected = {"a": np.asanyarray(1), "b": expected_inner}
        diff = DeepDiff(batch.to_dict(recurse=True), expected)
        assert not diff

    @staticmethod
    def test_to_dict_array() -> None:
        """A top-level numpy array value is expected to come back unchanged."""
        batch = Batch(a=np.array([1, 2, 3]))
        expected = {"a": np.array([1, 2, 3])}
        diff = DeepDiff(batch.to_dict(), expected)
        assert not diff

    @staticmethod
    def test_to_dict_nested_batch_with_array() -> None:
        """A numpy array inside a nested batch, converted with ``recurse=True``."""
        inner = Batch(c=np.array([4, 5]))
        batch = Batch(a=1, b=inner)
        expected = {"a": np.asanyarray(1), "b": {"c": np.array([4, 5])}}
        diff = DeepDiff(batch.to_dict(recurse=True), expected)
        assert not diff

    @staticmethod
    def test_to_dict_torch_tensor() -> None:
        """Values that originate from torch tensors.

        Note that both tensors are turned into numpy arrays before they are stored
        or compared, so the batch itself never holds a ``torch.Tensor`` here.
        """
        stored = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        batch = Batch(a=stored)
        reference = torch.tensor([1.0, 2.0]).detach().cpu().numpy()
        expected = {"a": reference}
        diff = DeepDiff(batch.to_dict(), expected)
        assert not diff

    @staticmethod
    def test_to_dict_nested_batch_with_torch_tensor() -> None:
        """Nested variant of the torch-origin case; values are numpy arrays when stored."""
        stored = torch.tensor([4, 5]).detach().cpu().numpy()
        inner = Batch(c=stored)
        batch = Batch(a=1, b=inner)
        reference = torch.tensor([4, 5]).detach().cpu().numpy()
        expected = {"a": np.asanyarray(1), "b": {"c": reference}}
        diff = DeepDiff(batch.to_dict(recurse=True), expected)
        assert not diff


class TestToNumpy:
    """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` ."""

    @staticmethod
    def test_to_numpy() -> None:
        """The static variant returns a different object and leaves the input's tensors alone."""
        original = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        converted: Batch = Batch.to_numpy(original)
        assert id(original) != id(converted)

        # The input batch still holds torch tensors, including in the nested entry.
        assert isinstance(original.b, torch.Tensor)
        assert isinstance(original.c.d, torch.Tensor)

        # The returned batch holds numpy arrays in the same places.
        assert isinstance(converted.b, np.ndarray)
        assert isinstance(converted.c.d, np.ndarray)

    @staticmethod
    def test_to_numpy_() -> None:
        """The in-place variant converts the tensors of the batch it is called on."""
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        identity_before = id(batch)
        batch.to_numpy_()
        assert identity_before == id(batch)
        assert isinstance(batch.b, np.ndarray)
        assert isinstance(batch.c.d, np.ndarray)


if __name__ == "__main__":
test_batch()
test_batch_over_batch()
Expand Down
46 changes: 39 additions & 7 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import torch
from deepdiff import DeepDiff

_SingleIndexType = slice | int | EllipsisType
IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...]
Expand Down Expand Up @@ -268,7 +269,15 @@ def __repr__(self) -> str:
def __iter__(self) -> Iterator[Self]:
...

def to_numpy(self) -> None:
def __eq__(self, other: Any) -> bool:
...

@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
"""Change all torch.Tensor to numpy.ndarray and return a new Batch."""
...

def to_numpy_(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
...

Expand Down Expand Up @@ -396,7 +405,7 @@ def split(
"""
...

def to_dict(self) -> dict[str, Any]:
def to_dict(self, recurse: bool = False) -> dict[str, Any]:
...

def to_list_of_dicts(self) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -433,11 +442,11 @@ def __init__(
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore

def to_dict(self) -> dict[str, Any]:
def to_dict(self, recurse: bool = False) -> dict[str, Any]:
dantp-ai marked this conversation as resolved.
Show resolved Hide resolved
result = {}
for k, v in self.__dict__.items():
if isinstance(v, Batch):
v = v.to_dict()
if recurse and isinstance(v, Batch):
dantp-ai marked this conversation as resolved.
Show resolved Hide resolved
v = v.to_dict(recurse=recurse)
result[k] = v
return result

Expand Down Expand Up @@ -503,6 +512,17 @@ def __getitem__(self, index: str | IndexType) -> Any:
return new_batch
raise IndexError("Cannot access item from empty Batch object.")

def __eq__(self, other: Any) -> bool:
    """Compare this batch with ``other`` by content.

    Objects that are not instances of this batch's class are never equal to it.
    Otherwise both operands are passed through ``Batch.to_numpy`` and then
    ``to_dict(recurse=True)``, and the two resulting dicts are diffed with
    ``DeepDiff``. The batches are equal exactly when that diff is empty.

    NOTE(review): no ``__hash__`` is visible next to this method. A class that
    overrides ``__eq__`` without defining ``__hash__`` becomes unhashable, so
    confirm that this is intended.

    :param other: the object to compare against.
    :return: True if ``other`` is a batch with the same content, False otherwise.
    """
    if not isinstance(other, self.__class__):
        return False

    # Work on tensor-free conversions of both operands before diffing.
    numpy_only = [Batch.to_numpy(candidate) for candidate in (self, other)]
    this_dict, other_dict = (converted.to_dict(recurse=True) for converted in numpy_only)

    difference = DeepDiff(this_dict, other_dict)
    return not difference

def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
Expand Down Expand Up @@ -602,12 +622,24 @@ def __repr__(self) -> str:
self_str = self.__class__.__name__ + "()"
return self_str

def to_numpy(self) -> None:
@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
    """Return a copy of ``batch`` in which every torch.Tensor is a numpy.ndarray.

    The input batch is not modified. It is deep-copied exactly once, and the
    copy is then converted in place by ``to_numpy_``, which also descends into
    nested ``Batch`` values.

    :param batch: the batch to convert.
    :return: a new batch with the same structure as ``batch`` and with its
        torch tensors replaced by numpy arrays.
    """
    # A single deep copy already duplicates every nested batch, so the nested
    # levels must not be copied again; converting the copy in place avoids that
    # and keeps the conversion logic in one place.
    result = deepcopy(batch)
    result.to_numpy_()
    return result

def to_numpy_(self) -> None:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor):
self.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(obj, Batch):
obj.to_numpy()
obj.to_numpy_()

def to_torch(
self,
Expand Down
2 changes: 1 addition & 1 deletion tianshou/data/utils/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray:
return np.array(None, dtype=object)
if isinstance(x, dict | Batch):
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
x.to_numpy()
x.to_numpy_()
return x
if isinstance(x, list | tuple):
return to_numpy(_parse_value(x))
Expand Down
Loading