forked from rainorangelemon/gnn-motion-planning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdijkstra.py
108 lines (84 loc) · 3.35 KB
/
dijkstra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import torch
from environment import KukaEnv, MazeEnv
from environment import Kuka2Env
from torch_geometric.nn import knn_graph
from collections import defaultdict
from time import time
import pickle
from tqdm import tqdm
from torch_sparse import coalesce
# Sentinel meaning "no finite cost": construct_graph() assigns it to every
# edge that fails the environment's collision check.
INFINITY = float('inf')
def construct_graph(env, points, check_collision=True):
    """Build a symmetrised 5-nearest-neighbour graph over ``points``.

    Every candidate edge is passed to ``env._edge_fp``; edges for which it
    returns a truthy value get their Euclidean length as cost, all others
    get ``INFINITY``.

    Args:
        env: Environment object exposing ``_edge_fp(a, b)``.
        points: Sampled configurations. Presumably a 2-D numpy array with one
            row per point (rows are subtracted from each other and the whole
            thing is fed to ``torch.FloatTensor``) -- confirm with callers.
        check_collision: Accepted for interface compatibility; the current
            implementation never reads it and always queries ``env._edge_fp``.

    Returns:
        A tuple ``(edge_cost, neighbors, edge_index, edge_free)`` where

        * ``edge_cost`` maps a node to the list of costs of its edges,
        * ``neighbors`` maps a node to the list of nodes at the other end of
          those edges, in the same order as ``edge_cost``,
        * ``edge_index`` is the numpy array of edges, one ``(a, b)`` row per
          edge,
        * ``edge_free`` is a list of booleans, one per row of ``edge_index``,
          telling whether ``env._edge_fp`` accepted that edge.
    """
    # kNN edges (self-loops included) plus their reversed copies; coalesce is
    # then applied to the concatenation (it is expected to merge duplicates).
    knn_edges = knn_graph(torch.FloatTensor(points), k=5, loop=True)
    both_directions = torch.cat((knn_edges, knn_edges.flip(0)), dim=-1)
    n_points = len(points)
    merged, _ = coalesce(both_directions, None, n_points, n_points)
    edge_index = merged.data.cpu().numpy().T

    edge_cost = defaultdict(list)
    neighbors = defaultdict(list)
    edge_free = []
    for tail, head in edge_index:
        if env._edge_fp(points[tail], points[head]):
            cost, is_free = np.linalg.norm(points[head] - points[tail]), True
        else:
            cost, is_free = INFINITY, False
        # Both dictionaries are keyed by the second endpoint so that
        # edge_cost[n][i] is the cost of reaching neighbors[n][i].
        edge_cost[head].append(cost)
        edge_free.append(is_free)
        neighbors[head].append(tail)

    return edge_cost, neighbors, edge_index, edge_free
def min_dist(q, dist):
    """
    Pick the node of q whose entry in dist is smallest.

    Ties go to the node that q yields first; an empty q gives None.
    """
    best = None
    for candidate in q:
        # Keep the current best unless the candidate is strictly closer.
        if best is not None and not dist[candidate] < dist[best]:
            continue
        best = candidate
    return best
def dijkstra(nodes, edges, costs, source):
    """Compute single-source shortest paths with Dijkstra's algorithm.

    The frontier is kept in a binary heap, so the closest node is popped
    directly rather than found by scanning every unvisited node on each
    iteration.

    Args:
        nodes: Iterable of hashable node identifiers.
        edges: Mapping from a node to the sequence of its neighbours.
        costs: Mapping from a node to the sequence of edge costs, parallel to
            ``edges[node]`` (``costs[u][i]`` is the cost of ``edges[u][i]``).
            Costs must be non-negative; an infinite cost marks an unusable
            edge.
        source: Node from which distances are measured.

    Returns:
        A tuple ``(dist, prev)``. ``dist[v]`` is the length of the shortest
        path from ``source`` to ``v`` and ``prev[v]`` is the node preceding
        ``v`` on that path. Nodes that cannot be reached keep ``inf`` in both
        dictionaries, and ``prev[source]`` is ``source`` itself.

    Raises:
        KeyError: If a reachable node is missing from ``edges``/``costs``
            (and they are plain dicts), or a neighbour is not in ``nodes``.
    """
    import heapq
    import itertools

    unreached = float("inf")  # same value as the module-level INFINITY

    dist = {}
    prev = {}
    for v in nodes:  # initialization
        dist[v] = unreached  # unknown distance from source to v
        prev[v] = unreached  # no predecessor known yet

    # distance from source to source
    dist[source] = 0
    prev[source] = source

    # Heap entries are (distance, insertion order, node). The insertion
    # counter breaks distance ties so node ids are never compared with
    # each other (they only need to be hashable, not orderable).
    tie_breaker = itertools.count()
    frontier = [(0, next(tie_breaker), source)]

    while frontier:
        d_u, _, u = heapq.heappop(frontier)
        if d_u > dist[u]:
            # Stale entry: u was already reached through a shorter path.
            continue
        for index, v in enumerate(edges[u]):
            alt = d_u + costs[u][index]
            if alt < dist[v]:
                # a shorter path to v has been found
                dist[v] = alt
                prev[v] = u
                heapq.heappush(frontier, (alt, next(tie_breaker), v))

    return dist, prev
if __name__ == "__main__":
    # Script entry point: for every problem of the environment, sample a
    # random number of configurations, build the k-NN graph over them and
    # collect the result; the whole collection is pickled at the end.
    data = []
    n_sample = [50, 200, 1000]  # only referenced by the disabled loop below

    # Alternative environments that can be swapped in:
    # env = MazeEnv(dim=2, map_file="maze_files/mazes_4000.npz")
    # env = KukaEnv(kuka_file="kuka_iiwa/model_3.urdf", map_file="maze_files/kukas_13_3000.pkl")
    # env = Kuka2Env()
    env = KukaEnv(map_file='maze_files/kukas_7_4000.pkl')

    time0 = time()  # only referenced by the disabled timing print below

    n_problems = 4000
    # for n in n_sample:
    for problem_index in tqdm(range(n_problems)):
        env.init_new_problem(problem_index)
        n_points = np.random.randint(100, 400)
        points = env.uniform_sample(n=n_points)
        costs, nbrs, edges, free_mask = construct_graph(env, points)
        data.append((points, nbrs, costs, edges, free_mask))

        # Disabled: run Dijkstra from node 0 and draw a random reachable goal.
        # dist, prev = dijkstra(list(range(len(points))), neighbors, edge_cost, 0)
        # valid_goal = np.logical_and(np.array(list(dist.values())) != INFINITY, np.array(list(dist.values()))!=0)
        # goal_index = np.random.choice(len(valid_goal), p=valid_goal.astype(float)/sum(valid_goal))
        #
        # print(time()-time0)
        # print('yes')

    with open('data/pkl/kuka_prm_4000.pkl', 'wb') as out_file:
        pickle.dump(data, out_file, pickle.DEFAULT_PROTOCOL)