IS_testing_working_april15.py
from OffPolicyAgent import OffPolicyAgent
from random_walk_env import RandomWalkEnv
import numpy as np
# Smoke test: learn the value function of a deterministic target policy
# from data generated by a uniform-random behavior policy.
env = RandomWalkEnv(10)
lr = 0.01
discount = 0.5
# Behavior policy: 0.5 probability of choosing left or right in the random walk.
uniform_random_behavior = np.full(shape=(10, 2), fill_value=0.5, dtype=float)
# Construct the target policy: deterministically move right in every state.
target = np.zeros(shape=(10, 2), dtype=float)
target[:, 1] = 1.0
agent = OffPolicyAgent('RandomWalk', 256, env, target, uniform_random_behavior, lr, discount)
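# For reference, the per-step importance-sampling ratio such an agent is
# expected to apply is rho(s, a) = pi(a|s) / b(a|s). A minimal sketch from the
# two policy tables above (illustration only; the agent computes its own):
is_ratios = target / uniform_random_behavior  # hypothetical name; rho(s, left) = 0.0, rho(s, right) = 2.0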
# Print the initial value estimates. The second model input (all zeros here)
# is assumed to be an auxiliary per-sample input used by the agent's loss;
# zeros suffice for a plain forward pass.
states = agent.construct_features(range(10))
print(agent.model.predict([states, np.array([0.] * 10)]))
# Analytic values this test checks against: the terminal reward on the right
# is discounted 9 - i times from state i, so v(s_i) = discount**(9 - i).
true_value = np.array([discount**i for i in reversed(range(10))])
mses = []
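# Sanity check on the analytic values (assumes, as true_value encodes, that
# the environment pays reward 1 on reaching the right terminal, 0 elsewhere):
assert np.isclose(true_value[9], 1.0)          # rightmost state, undiscounted reward
assert np.isclose(true_value[0], discount**9)  # leftmost state, nine steps away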
for j in range(25):
    # Generate 50 episodes under the behavior policy, training after each.
    for i in range(50):
        agent.generate_episode()
        agent.train_batch(32, 8)  # arguments assumed: batch size 32, 8 gradient steps
    # Print the current value estimates and record the MSE against true_value.
    states = agent.construct_features(range(10))
    prediction = agent.model.predict([states, np.array([0.] * 10)])
    print(prediction)
    prediction = prediction.flatten()
    mse = np.mean(np.square(true_value - prediction))
    mses.append(mse)
print("final MeanSquareError", mses[-1])