buffer.py
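"""In-memory episodic replay buffer with fixed-length subsequence sampling."""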
import random
from typing import List

import numpy as np
from gym import spaces


class ReplayBuffer:
    """Stores completed episodes and samples fixed-length subsequences."""

    def __init__(self, action_space: spaces.Space, balance: bool = True):
        self.current_episode: List[dict] = []
        self.action_space = action_space
        self.balance = balance
        self.episodes: List[dict] = []

def start_episode(self, obs: dict):
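        """Begin a new episode from the first observation after a reset."""
        # Pad the first timestep with a dummy action and zero reward so every
        # transition in the episode has the same set of keys.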
transition = obs.copy()
transition['action'] = np.zeros(self.action_space.shape)
transition['reward'] = 0.0
transition['discount'] = 1.0
self.current_episode = [transition]

    def add(self, obs: dict, action: np.ndarray, reward: float, done: bool,
            info: dict):
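        """Append one transition; on `done`, finalize and store the episode."""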
transition = obs.copy()
transition['action'] = action
transition['reward'] = reward
transition['discount'] = info.get('discount',
np.array(1 - float(done)))
self.current_episode.append(transition)
if done:
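            # Transpose the list of per-step dicts into a dict of sequences.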
episode = {
k: [t[k] for t in self.current_episode]
for k in self.current_episode[0]
}
episode = {k: self.convert(v) for k, v in episode.items()}
self.episodes.append(episode)
self.current_episode = []

    def sample_single_episode(self, length: int):
        """Sample one subsequence of `length` observations from a random episode."""
        while True:
            episode = random.choice(self.episodes)
            total = len(next(iter(episode.values())))
            available = total - length
            if available < 1:
                # The episode is too short for the requested length; resample.
                print(f'Skipped short episode of length {total}.')
                continue
            if self.balance:
                # Clipping a draw from [0, total] at `available` oversamples
                # subsequences that reach the end of the episode.
                index = min(random.randint(0, total), available)
            else:
                index = random.randint(0, available)
            return {k: v[index:index + length] for k, v in episode.items()}

    def sample(self, batch_size: int, length: int):
        """Sample `batch_size` subsequences and stack them per key.

        Args:
            batch_size: number of subsequences to draw.
            length: number of observations per subsequence, i.e. the number
                of transitions plus one.
        """
episodes = [self.sample_single_episode(length) for _ in range(batch_size)]
batch = {}
for key in episodes[0]:
batch[key] = np.array([ep[key] for ep in episodes])
return batch

    def convert(self, value):
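        """Cast episode values to compact dtypes (float32 / int32 / uint8)."""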
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
dtype = np.float32
elif np.issubdtype(value.dtype, np.signedinteger):
dtype = np.int32
elif np.issubdtype(value.dtype, np.uint8):
dtype = np.uint8
else:
raise NotImplementedError(value.dtype)
return value.astype(dtype)


if __name__ == '__main__':
from env import make_dmc_env
import time
env = make_dmc_env(name='cartpole_swingup')
replay_buffer = ReplayBuffer(action_space=env.action_space, balance=True)
steps = 0
obs = env.reset()
replay_buffer.start_episode(obs)
start = time.perf_counter()
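    # Collect random-policy experience; periodically sample a batch and
    # report collection throughput.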
while True:
action = env.action_space.sample()
obs, reward, done, info = env.step(action)
replay_buffer.add(obs, action, reward, done, info)
if done:
obs = env.reset()
replay_buffer.start_episode(obs)
steps += 1
if steps % 2500 == 0:
data = replay_buffer.sample(batch_size=32, length=15)
for key in data:
print(key, data[key].shape)
            elapsed = time.perf_counter() - start
            # frames = steps * 2; the factor of 2 presumably matches the
            # action repeat of the environment returned by make_dmc_env.
            print(
                f'steps: {steps}, frames: {steps * 2}, time: {elapsed:.2f}s, '
                f'fps: {steps * 2 / elapsed:.2f}'
            )