import warnings
from typing import Dict, Tuple, Union

import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F

def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
    """
    Check if an image observation space (see ``is_image_space``)
    is channels-first (CxHxW, True) or channels-last (HxWxC, False).

    Use the heuristic that the channel dimension is the smallest of the three.
    If the second dimension is the smallest, a warning is emitted, as this case is ambiguous.

    :param observation_space:
    :return: True if observation space is channels-first image, False if channels-last.
    """
    smallest_dimension = np.argmin(observation_space.shape).item()
    if smallest_dimension == 1:
        warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
    return smallest_dimension == 0
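
# Illustrative sketch (not part of the original module): the heuristic picks the
# smallest dimension as the channel axis, so a stacked Atari-style frame behaves
# as follows.
#
#   >>> chw = spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
#   >>> is_image_space_channels_first(chw)
#   True
#   >>> hwc = spaces.Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)
#   >>> is_image_space_channels_first(hwc)
#   False
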
def is_image_space(
    observation_space: spaces.Space,
    channels_last: bool = True,
    check_channels: bool = False,
) -> bool:
    """
    Check if an observation space has the shape, limits and dtype
    of a valid image.
    The check is conservative, so that it returns False if there is a doubt.

    Valid images: RGB, RGBD, GrayScale with values in [0, 255]

    :param observation_space:
    :param channels_last:
    :param check_channels: Whether to check the number of channels or not.
        e.g., with frame-stacking, the observation space may have more channels than expected.
    :return:
    """
    if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
        # Check the dtype: valid images are stored as uint8
        if observation_space.dtype != np.uint8:
            return False
        # Check the value range: pixel values must span exactly [0, 255]
        if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
            return False
        # Skip the channels check if requested
        if not check_channels:
            return True
        # Check the number of channels
        if channels_last:
            n_channels = observation_space.shape[-1]
        else:
            n_channels = observation_space.shape[0]
        # GrayScale, RGB, RGBD
        return n_channels in [1, 3, 4]
    return False
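
# Illustrative sketch (not part of the original module): only 3-dimensional
# uint8 boxes with exact [0, 255] bounds qualify as images.
#
#   >>> is_image_space(spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8))
#   True
#   >>> is_image_space(spaces.Box(low=-1.0, high=1.0, shape=(84, 84, 3), dtype=np.float32))
#   False
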
def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
    """
    Handle the different cases for images,
    as PyTorch uses the channels-first format.

    :param observation:
    :param observation_space:
    :return: channels-first observation if observation is an image
    """
    # Avoid circular import
    from stable_baselines3.common.vec_env import VecTransposeImage

    if is_image_space(observation_space):
        if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
            # Try to re-order the channels
            transpose_obs = VecTransposeImage.transpose_image(observation)
            if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
                observation = transpose_obs
    return observation
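
# Illustrative sketch (not part of the original module): a channels-last
# observation is re-ordered to match a channels-first space.
#
#   >>> space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
#   >>> obs = np.zeros((84, 84, 3), dtype=np.uint8)
#   >>> maybe_transpose(obs, space).shape
#   (3, 84, 84)
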
def preprocess_obs(
    obs: th.Tensor,
    observation_space: spaces.Space,
    normalize_images: bool = True,
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
    """
    Preprocess observation to be fed to a neural network.
    For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
    For discrete observations, it creates a one-hot vector.

    :param obs: Observation
    :param observation_space:
    :param normalize_images: Whether to normalize images or not
        (True by default)
    :return:
    """
    if isinstance(observation_space, spaces.Box):
        if is_image_space(observation_space) and normalize_images:
            return obs.float() / 255.0
        return obs.float()

    elif isinstance(observation_space, spaces.Discrete):
        # One-hot encoding, converted to float to avoid errors
        return F.one_hot(obs.long(), num_classes=observation_space.n).float()

    elif isinstance(observation_space, spaces.MultiDiscrete):
        # Tensor concatenation of one-hot encodings of each Categorical sub-space
        return th.cat(
            [
                F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
                for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
            ],
            dim=-1,
        ).view(obs.shape[0], sum(observation_space.nvec))

    elif isinstance(observation_space, spaces.MultiBinary):
        return obs.float()

    elif isinstance(observation_space, spaces.Dict):
        # Do not modify by reference the original observation
        preprocessed_obs = {}
        for key, _obs in obs.items():
            preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
        return preprocessed_obs

    else:
        raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
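
# Illustrative sketch (not part of the original module): a batch of Discrete
# observations becomes a batch of one-hot float vectors.
#
#   >>> preprocess_obs(th.tensor([1, 3]), spaces.Discrete(4))
#   tensor([[0., 1., 0., 0.],
#           [0., 0., 0., 1.]])
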
def get_obs_shape(
    observation_space: spaces.Space,
) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
    """
    Get the shape of the observation (useful for the buffers).

    :param observation_space:
    :return:
    """
    if isinstance(observation_space, spaces.Box):
        return observation_space.shape
    elif isinstance(observation_space, spaces.Discrete):
        # Observation is an int
        return (1,)
    elif isinstance(observation_space, spaces.MultiDiscrete):
        # Number of discrete features
        return (int(len(observation_space.nvec)),)
    elif isinstance(observation_space, spaces.MultiBinary):
        # Number of binary features
        return (int(observation_space.n),)
    elif isinstance(observation_space, spaces.Dict):
        return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}
    else:
        raise NotImplementedError(f"{observation_space} observation space is not supported")
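
# Illustrative sketch (not part of the original module):
#
#   >>> get_obs_shape(spaces.Discrete(5))
#   (1,)
#   >>> get_obs_shape(spaces.MultiDiscrete([3, 2]))
#   (2,)
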
def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
    """
    Get the dimension of the observation space when flattened.
    It does not apply to image observation spaces.
    Used by the ``FlattenExtractor`` to compute the input shape.

    :param observation_space:
    :return:
    """
    # See issue https://github.com/openai/gym/issues/1915
    # It may be a problem for Dict/Tuple spaces too...
    if isinstance(observation_space, spaces.MultiDiscrete):
        return sum(observation_space.nvec)
    else:
        # Use Gym internal method
        return spaces.utils.flatdim(observation_space)
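
# Illustrative sketch (not part of the original module): for MultiDiscrete,
# the flattened dim is the total one-hot size, matching ``preprocess_obs``.
#
#   get_flattened_obs_dim(spaces.MultiDiscrete([3, 2]))              # == 5 (3 + 2)
#   get_flattened_obs_dim(spaces.Box(low=-1, high=1, shape=(2, 2)))  # == 4
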
def get_action_dim(action_space: spaces.Space) -> int:
    """
    Get the dimension of the action space.

    :param action_space:
    :return:
    """
    if isinstance(action_space, spaces.Box):
        return int(np.prod(action_space.shape))
    elif isinstance(action_space, spaces.Discrete):
        # Action is an int
        return 1
    elif isinstance(action_space, spaces.MultiDiscrete):
        # Number of discrete actions
        return int(len(action_space.nvec))
    elif isinstance(action_space, spaces.MultiBinary):
        # Number of binary actions
        return int(action_space.n)
    else:
        raise NotImplementedError(f"{action_space} action space is not supported")
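
# Illustrative sketch (not part of the original module):
#
#   get_action_dim(spaces.Box(low=-1, high=1, shape=(2, 3)))  # == 6 (flattened)
#   get_action_dim(spaces.Discrete(4))                        # == 1 (single int)
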
def check_for_nested_spaces(obs_space: spaces.Space):
    """
    Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples).
    If so, raise an exception to explain that this is not supported.

    :param obs_space: an observation space
    :return:
    """
    if isinstance(obs_space, (spaces.Dict, spaces.Tuple)):
        sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces
        for sub_space in sub_spaces:
            if isinstance(sub_space, (spaces.Dict, spaces.Tuple)):
                raise NotImplementedError(
                    "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)."
                )
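
# Illustrative sketch (not part of the original module): a flat Dict space
# passes, while a Dict nested inside a Dict raises NotImplementedError.
#
#   check_for_nested_spaces(spaces.Dict({"pos": spaces.Box(low=-1, high=1, shape=(2,))}))  # OK
#   check_for_nested_spaces(spaces.Dict({"a": spaces.Dict({"b": spaces.Discrete(2)})}))    # raises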