-
-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
JPEG encoding and decoding if the observation is an image #275
base: main
Are you sure you want to change the base?
Changes from 5 commits
8106db5
5bf0cc3
a4762c8
a348f1a
d843f9f
2aa2f44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
import pathlib | ||
from itertools import zip_longest | ||
from typing import Any, Dict, Iterable, Optional, Sequence | ||
from PIL import Image | ||
import io | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
|
@@ -178,11 +180,20 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0): | |
return pa.StructArray.from_arrays(arrays, names=names) | ||
elif isinstance(space, _FIXEDLIST_SPACES): | ||
values = np.asarray(values) | ||
assert values.shape[1:] == space.shape | ||
values = values.reshape(values.shape[0], -1) | ||
values = np.pad(values, ((0, pad), (0, 0))) | ||
dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1]) | ||
return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype) | ||
if values.shape == (4, 84, 84) and values.dtype == np.uint8: # check for image observation (4 stacked greyscale images) | ||
jpeg_bytes = [] | ||
for frame in values: | ||
img = Image.fromarray(frame) | ||
buffer = io.BytesIO() | ||
img.save(buffer, format="JPEG") | ||
jpeg_bytes.append(buffer.getvalue()) | ||
return pa.array(jpeg_bytes, type=pa.binary()) | ||
else: | ||
assert values.shape[1:] == space.shape | ||
values = values.reshape(values.shape[0], -1) | ||
values = np.pad(values, ((0, pad), (0, 0))) | ||
dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1]) | ||
return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype) | ||
elif isinstance(space, gym.spaces.Discrete): | ||
values = np.asarray(values).reshape(-1, 1) | ||
values = np.pad(values, ((0, pad), (0, 0))) | ||
|
@@ -207,8 +218,15 @@ def _decode_space(space, values: pa.Array): | |
] | ||
) | ||
elif isinstance(space, _FIXEDLIST_SPACES): | ||
data = np.stack(values.to_numpy(zero_copy_only=False)) | ||
return data.reshape(-1, *space.shape) | ||
if values.type == pa.binary(): # check for binary data (JPEG) | ||
jpeg_images = [] | ||
for jpeg_bytes in values: | ||
image = Image.open(io.BytesIO(jpeg_bytes)).convert("L") # decode JPEG and convert to greyscale | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should work with non-grayscale images as well. |
||
jpeg_images.append(np.array(image)) | ||
return np.stack(jpeg_images) | ||
else: | ||
data = np.stack(values.to_numpy(zero_copy_only=False)) | ||
return data.reshape(-1, *space.shape) | ||
elif isinstance(space, gym.spaces.Discrete): | ||
return values.to_numpy() | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should not constrain ourselves with 84 x 84 images, but we should accept any size.
Same reasoning for the 4 dimension.
I think the discriminant here is to have Box, with at least 2 dim, and with type uint8. Then you can put a logging.warn saying you are considering it as image, and if this is not intended to disable it via a flag "image_observation" which is defaulted to None. Then you can compute the value of the flag during init (and warn just once). I will clarify later in our meeting.