forked from kschweig/OfflineRL
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathex_06.py
110 lines (87 loc) · 4.25 KB
/
ex_06.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from source.train_online import train_online
from source.train_offline import train_offline
from source.offline_ds_evaluation.evaluator import Evaluator
from source.offline_ds_evaluation.metrics_manager import MetricsManager
from multiprocessing import Pool
import os
import copy
import pickle
import numpy as np
import argparse
# --- project parameters -----------------------------------------------------
# Environment id, forwarded as ``envid`` to the online/offline trainers and
# used as prefix of the pickled buffer / metrics file names.
env = "SpaceInvaders-MinAtar-v0"
discount = 0.99
# Dataset flavours; each one is stored as "<env>_run<use_run>_<type>.pkl".
buffer_types = [
    "random",
    "mixed",
    "er",
    "noisy",
    "fully",
]
# Offline agents trained on every selected dataset.
agent_types = [
    "BC",
    "BVE",
    "MCE",
    "DQN",
    "QRDQN",
    "REM",
    "BCQ",
    "CQL",
    "CRR",
]
# Number of dataset runs (1..multiple_useruns) and size of the worker pool.
multiple_useruns = 5

# --- experiment parameters --------------------------------------------------
experiment = 6
seed = 42

# --- hyperparameters for online training ------------------------------------
behavioral = "DQN"
buffer_size = 500_000
transitions_online = 2_000_000

# --- hyperparameters for offline training -----------------------------------
transitions_offline = transitions_online * 5
batch_size = 128
# One learning rate per agent, indexed in parallel with ``agent_types``.
lr = [1e-4 for _ in agent_types]
def create_ds(args):
    """Run online training of the behavioural agent for a single run.

    ``args`` is the run index handed out by ``Pool.map``; it is forwarded to
    ``train_online`` as ``run``. Every other setting comes from the
    module-level configuration. Note that all runs share the same ``seed``.

    NOTE(review): presumably ``train_online`` writes the buffer pickles that
    ``assess_env`` later reads -- confirm in ``source.train_online``.
    """
    online_settings = dict(
        experiment=experiment,
        agent_type=behavioral,
        discount=discount,
        envid=env,
        transitions=transitions_online,
        buffer_size=buffer_size,
        run=args,
        seed=seed,
    )
    train_online(**online_settings)
def train(args):
    """Train the offline agents on the dataset(s) of one behavioural run.

    ``args`` is a ``(use_run, run, dataset)`` tuple:

    * ``use_run`` -- forwarded to ``train_offline`` as ``use_run``.
    * ``run``     -- forwarded as ``run`` and added to the base ``seed``.
    * ``dataset`` -- index into ``buffer_types`` to train on; a negative
      value trains on every buffer type.

    Agents are processed in ``agent_types`` order, and for each agent the
    buffer types in ``buffer_types`` order.
    """
    use_run, run, dataset = args
    restrict_to_one = dataset >= 0

    for agent_idx, agent in enumerate(agent_types):
        for buffer_idx, buffer_type in enumerate(buffer_types):
            # A non-negative ``dataset`` selects exactly one buffer type.
            if restrict_to_one and dataset != buffer_idx:
                continue

            train_offline(
                experiment=experiment,
                envid=env,
                agent_type=agent,
                buffer_type=buffer_type,
                discount=discount,
                transitions=transitions_offline,
                batch_size=batch_size,
                lr=lr[agent_idx],
                use_run=use_run,
                run=run,
                seed=seed + run,
                # Only the MCE agent is trained with use_remaining_reward=True.
                use_remaining_reward=(agent == "MCE"),
            )
def assess_env(args):
    """Evaluate every stored buffer of one run and pickle the collected metrics.

    ``args`` is the ``use_run`` index handed out by ``Pool.map``. For each
    entry of ``buffer_types`` the file
    ``data/ex<experiment>/<env>_run<use_run>_<buffer_type>.pkl`` is loaded,
    passed through ``Evaluator.evaluate`` and summarised into one row of a
    ``MetricsManager``. The manager is finally written to
    ``data/ex<experiment>/metrics_<env>_run<use_run>.pkl``.

    Returns ``None``; the only result is the written metrics file.

    NOTE: ``pickle.load`` can execute arbitrary code contained in the file.
    Only buffers written by this project itself should be placed in the data
    directory.
    """
    use_run = args
    data_dir = os.path.join("data", f"ex{experiment}")

    # State/action limits are always derived from the "er" buffer, regardless
    # of which buffer is evaluated below.
    with open(os.path.join(data_dir, f"{env}_run{use_run}_er.pkl"), "rb") as f:
        buffer = pickle.load(f)

    # Flat list [min_0, max_0, min_1, max_1, ...] over the second axis of
    # buffer.state.
    state_limits = []
    for axis in range(len(buffer.state[0])):
        column = buffer.state[:, axis]
        state_limits.append(np.min(column))
        state_limits.append(np.max(column))

    # Action limits are the state limits followed by the action min/max.
    action_limits = copy.deepcopy(state_limits)
    action_limits.append(np.min(buffer.action))
    action_limits.append(np.max(buffer.action))

    mm = MetricsManager(experiment)

    for buffer_type in buffer_types:
        with open(os.path.join(data_dir, f"{env}_run{use_run}_{buffer_type}.pkl"), "rb") as f:
            buffer = pickle.load(f)

        # The evaluator receives "done" flags, i.e. the inverted not_done mask.
        evaluator = Evaluator(env, buffer_type, buffer.state, buffer.action, buffer.reward,
                              np.invert(buffer.not_done))

        # Explicit limits are only passed for the two classic-control
        # environments; every other environment uses the evaluator defaults.
        if env in ('CartPole-v1', 'MountainCar-v0'):
            rets, us, usa, ents, sps, epls = evaluator.evaluate(state_limits=state_limits,
                                                                action_limits=action_limits,
                                                                epochs=10)
        else:
            rets, us, usa, ents, sps, epls = evaluator.evaluate(epochs=10)

        # NOTE(review): us / usa are stored as returned by the evaluator
        # (presumably unique-state and unique-state-action measures -- confirm
        # in Evaluator); the remaining outputs are reduced to (mean, std).
        mm.append([env, buffer_type, (np.mean(rets), np.std(rets)), us, usa, (np.mean(ents), np.std(ents)),
                   (np.mean(sps), np.std(sps)), (np.mean(epls), np.std(epls))])

    with open(os.path.join(data_dir, f"metrics_{env}_run{use_run}.pkl"), "wb") as f:
        pickle.dump(mm, f)
if __name__ == '__main__':
    # Command-line entry point: either create and assess the datasets
    # (--online) or train the offline agents on them (default).
    parser = argparse.ArgumentParser()
    parser.add_argument("--online", action="store_true")  # run online task
    parser.add_argument("--run", default=1, type=int)  # which number of run for dataset?
    parser.add_argument("--dataset", default=-1, type=int)  # which dataset to use
    args = parser.parse_args()

    # Validate explicitly instead of with ``assert`` (asserts vanish under
    # ``python -O``), and bound by the actual number of buffer types rather
    # than a hard-coded 5.
    if args.dataset >= len(buffer_types):
        parser.error("dataset must be within the created ones or negative for all!")

    use_runs = range(1, multiple_useruns + 1)

    with Pool(multiple_useruns, maxtasksperchild=1) as p:
        if args.online:
            # All datasets are created first, then assessed.
            p.map(create_ds, use_runs)
            p.map(assess_env, use_runs)
        else:
            # One task per dataset run. Building the tuples from ``use_runs``
            # keeps the task count tied to ``multiple_useruns``; the previous
            # ``[...] * 5`` lists would silently truncate the zip if that
            # constant were raised.
            p.map(train, [(use_run, args.run, args.dataset) for use_run in use_runs])