Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

JPEG encoding and decoding if the observation is an image #275

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions .idea/Minari.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 25 additions & 7 deletions minari/dataset/_storages/arrow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pathlib
from itertools import zip_longest
from typing import Any, Dict, Iterable, Optional, Sequence
from PIL import Image
import io

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -178,11 +180,20 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0):
return pa.StructArray.from_arrays(arrays, names=names)
elif isinstance(space, _FIXEDLIST_SPACES):
values = np.asarray(values)
assert values.shape[1:] == space.shape
values = values.reshape(values.shape[0], -1)
values = np.pad(values, ((0, pad), (0, 0)))
dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1])
return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype)
if values.shape == (4, 84, 84) and values.dtype == np.uint8: # check for image observation (4 stacked greyscale images)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should not constrain ourselves with 84 x 84 images, but we should accept any size.

Same reasoning for the 4 dimension.

I think the discriminant here is to have Box, with at least 2 dim, and with type uint8. Then you can put a logging.warn saying you are considering it as image, and if this is not intended to disable it via a flag "image_observation" which is defaulted to None. Then you can compute the value of the flag during init (and warn just once). I will clarify later in our meeting.

jpeg_bytes = []
for frame in values:
img = Image.fromarray(frame)
buffer = io.BytesIO()
img.save(buffer, format="JPEG")
jpeg_bytes.append(buffer.getvalue())
return pa.array(jpeg_bytes, type=pa.binary())
else:
assert values.shape[1:] == space.shape
values = values.reshape(values.shape[0], -1)
values = np.pad(values, ((0, pad), (0, 0)))
dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1])
return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype)
elif isinstance(space, gym.spaces.Discrete):
values = np.asarray(values).reshape(-1, 1)
values = np.pad(values, ((0, pad), (0, 0)))
Expand All @@ -207,8 +218,15 @@ def _decode_space(space, values: pa.Array):
]
)
elif isinstance(space, _FIXEDLIST_SPACES):
data = np.stack(values.to_numpy(zero_copy_only=False))
return data.reshape(-1, *space.shape)
if values.type == pa.binary(): # check for binary data (JPEG)
jpeg_images = []
for jpeg_bytes in values:
image = Image.open(io.BytesIO(jpeg_bytes)).convert("L") # decode JPEG and convert to greyscale
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should work with non-grayscale images as well.

jpeg_images.append(np.array(image))
return np.stack(jpeg_images)
else:
data = np.stack(values.to_numpy(zero_copy_only=False))
return data.reshape(-1, *space.shape)
elif isinstance(space, gym.spaces.Discrete):
return values.to_numpy()
else:
Expand Down
22 changes: 19 additions & 3 deletions minari/dataset/_storages/hdf5_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from collections import OrderedDict
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union

from PIL import Image
import io
import gymnasium as gym
import numpy as np

Expand Down Expand Up @@ -107,6 +108,13 @@ def _decode_space(
assert isinstance(hdf_ref, h5py.Dataset)
result = map(lambda string: string.decode("utf-8"), hdf_ref[()])
return list(result)
elif isinstance(hdf_ref, h5py.Dataset) and hdf_ref.dtype.kind == 'O': # check for binary data (JPEG)
jpeg_images = []
jpeg_bytes_list = hdf_ref[()]
for jpeg_bytes in jpeg_bytes_list:
image = Image.open(io.BytesIO(jpeg_bytes)).convert("L") # decode JPEG and convert to greyscale
jpeg_images.append(np.array(image, dtype=np.uint8))
return np.stack(jpeg_images)
else:
assert isinstance(hdf_ref, h5py.Dataset)
return hdf_ref[()]
Expand Down Expand Up @@ -199,6 +207,15 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group):
if isinstance(data, dict):
episode_group_to_clear = _get_from_h5py(episode_group, key)
_add_episode_to_group(data, episode_group_to_clear)
elif isinstance(data, np.ndarray) and data.shape == (4, 84, 84) and data.dtype == np.uint8: # check for image observation (4 stacked greyscale images)
jpeg_bytes = []
for frame in data:
img = Image.fromarray(frame, mode="L")
buffer = io.BytesIO()
img.save(buffer, format="JPEG")
jpeg_bytes.append(buffer.getvalue())
dt = h5py.special_dtype(vlen=bytes)
episode_group.create_dataset(key, data=np.array(jpeg_bytes, dtype=object), dtype=dt, chunks=True)
elif isinstance(data, tuple):
dict_data = {f"_index_{i}": subdata for i, subdata in enumerate(data)}
episode_group_to_clear = _get_from_h5py(episode_group, key)
Expand All @@ -209,7 +226,6 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group):
dict_data = {key: [entry[key] for entry in data] for key in data[0].keys()}
episode_group_to_clear = _get_from_h5py(episode_group, key)
_add_episode_to_group(dict_data, episode_group_to_clear)

# leaf data
elif key in episode_group:
dataset = episode_group[key]
Expand Down Expand Up @@ -267,4 +283,4 @@ def unflatten_dict(d: Dict) -> Dict:
current[key] = {}
current = current[key]
current[keys[-1]] = v
return result
return result