-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: test_gsp.py
135 lines (109 loc) · 5.02 KB
/
test_gsp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import pybullet as p
import pybullet_data
import time
import matplotlib.pyplot as plt
import numpy as np
import random
import cv2
import torch
from sim_utils import SymFetch
from models.gsp_net import GSPNet
from models.goal_recognizer_net import GoalReconizerNet
from mpl_toolkits.axes_grid1 import ImageGrid
import imageio
import torch
from r3m import load_r3m
# Pick the compute device: prefer the GPU when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained R3M visual encoder (ResNet-50 backbone; resnet18 and
# resnet34 variants also exist) and freeze it in inference mode on `device`.
r3m = load_r3m("resnet50")
r3m.eval()
r3m.to(device)
if __name__ == '__main__':
    goal_images = []
    current_images = []

    # Inference only -- no gradients are needed anywhere below.
    with torch.no_grad():
        # Deterministic scene: fixed robot init and fixed block layout so the
        # camera view can match the goal images stored in goal.npy.
        fetch = SymFetch(random_init=False)
        fetch.generate_blocks(random_number=False, random_color=False, random_pos=False)

        state_dim = 2048         # R3M ResNet-50 embedding size
        joint_state_dim = 7 + 3  # 7 arm joint angles + gripper state; the extra 3
                                 # presumably matches len(fetch.get_gripper_state())
                                 # concatenated below -- TODO confirm against SymFetch
        action_dim = 7 + 1       # 7 joint-position deltas + 1 gripper open/close flag

        # Goal-conditioned policy (multi-step variant).
        model = GSPNet(state_dim, joint_state_dim, action_dim, num_actions=5)
        model.load_state_dict(torch.load('models/GSP_model_multistep.pt'))
        model.eval()

        # Goal recognizer: scores how close the current view is to the goal view.
        gr = GoalReconizerNet(state_dim, 0)
        gr.load_state_dict(torch.load('models/GoalRecognizer_net.pt'))
        gr.eval()

        goals = np.load('goal.npy')

        # Warm-up value for the policy's previous-action input; overwritten on
        # the first re-plan step below.
        last_action = torch.zeros(1, action_dim - 1).to(device)

        # Drive the arm to a fixed start pose before the policy takes over
        # (at most 40 IK iterations, stopping once within 0.095 of the target).
        dist = 1
        k = 0
        pos = np.array([0.7, 0.0, 0.5])
        while dist > 0.095 and k < 40:
            dist = fetch.move_to(pos)
            for _ in range(24):
                fetch.stepSimulation()
                time.sleep(1 / 240)
            k += 1

        last_joint_state = torch.from_numpy(fetch.get_joint_angles()).to(device).float()

        i = 0
        try:
            for goal_idx in range(len(goals)):
                # Embed the stored goal image with R3M (HWC -> 1x3x224x224).
                goal = r3m(torch.from_numpy(goals[goal_idx]['r3m']).to(device)
                           .permute(2, 0, 1).reshape(-1, 3, 224, 224))
                print('----goal {}-------'.format(goal_idx))

                # Show the goal image the policy is trying to reach.
                cv2.destroyAllWindows()
                img = goals[goal_idx]['r3m'].astype(np.uint8)
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                cv2.imshow('goal', img)
                cv2.waitKey(1000)

                gr_score = 0
                i_max = i + 1000  # budget of 1000 sim steps per goal
                while gr_score < 0.8 and i < i_max:
                    # Re-plan every 24 sim steps (~0.1 s at the 240 Hz step rate).
                    if i % 24 == 0:
                        im = torch.tensor(fetch.get_image(True))
                        # Current-view R3M embedding for the policy and recognizer.
                        features = r3m(im.permute(2, 0, 1).reshape(-1, 3, 224, 224))

                        joint_state = torch.from_numpy(fetch.get_joint_angles()).to(device).float()
                        full_joint_state = torch.cat(
                            (joint_state,
                             torch.tensor(fetch.get_gripper_state()).to(device))
                        ).view(1, -1).float()
                        # Previous action = realized joint-angle change since last re-plan.
                        last_action = joint_state - last_joint_state

                        output = model(features, full_joint_state, goal,
                                       last_action.view(1, -1))[0]
                        output = output.detach().cpu().numpy()

                        # First 7 outputs are joint deltas (applied to the current
                        # angles); the last output is the gripper open/close flag.
                        pos = output[0, :7] + fetch.get_joint_angles()
                        open_gripper = bool(round(output[0, -1]))
                        fetch.set_joint_angles(pos)
                        fetch.set_gripper(open=open_gripper)

                        last_joint_state = joint_state

                        # Goal-recognizer score and raw embedding distance, for
                        # the stop condition and for logging.
                        gr_score = gr(features, goal)
                        dist = ((features - goal) ** 2).sum().sqrt()
                        print(i, gr_score.item(), dist.item())

                    fetch.stepSimulation()
                    i += 1
                    time.sleep(1 / 240)

                # Brief pause between goals so the result is visible.
                time.sleep(1)
        except KeyboardInterrupt:
            pass  # Ctrl-C just ends the rollout early; cleanup happens below.
        finally:
            # BUGFIX: the original disconnected only inside the KeyboardInterrupt
            # handler, leaking the pybullet connection on a normal finish.
            p.disconnect()