-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualize_critic.py
32 lines (28 loc) · 1002 Bytes
/
visualize_critic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import numpy as np
from env.batched_env import BlackBox
from agents.pix_2_pix_agent import Agent
from matplotlib import pyplot as plt
# Visualise the critic's value estimates over the course of each episode.
#
# Loads a trained agent from "model.t", rolls it out in a BlackBox environment
# with a batch of two, records the critic value for batch element 0 at every
# step, and plots the recorded values once that element reports done.

device = torch.device("cpu")
env = BlackBox(30, batch_size=2, dims=2)
agent: Agent = Agent(env.observation_space, env.dims).to(device)
# map_location keeps the load working on CPU-only machines even when the
# checkpoint was written from a CUDA device.
agent.load_state_dict(torch.load("model.t", map_location=device))

s, t = env.reset()

for _ in range(100):
    values = []  # critic outputs for batch element 0, one entry per step
    done = False
    env.render()
    while not done:
        print("Time in env:", env.time[0].item(), "Max time:", env.T)
        # Inference only: no gradients are needed for visualisation.
        with torch.no_grad():
            action, _, _, value = agent.get_action_and_value(s, t)
        values.append(value.cpu().numpy()[0])
        # NOTE(review): the first value returned by env.step is discarded and
        # s / t are never rebound, so the agent only sees fresh observations
        # if the env updates those objects in place -- confirm against
        # BlackBox.step. The meaning of the second positional argument (True)
        # is also not visible from this file.
        _, _, dones, _ = env.step(action, True)
        print("Dones:", dones)
        # Only the first batch element decides when the episode ends.
        done = dones[0]

    # Drop the previous episode's figure before drawing the new curve;
    # plt.plot creates a fresh figure and axes on demand.
    plt.close()
    plt.plot(values)
    plt.title("Critic guess as a function of t for the last seen state")
    plt.show()