Skip to content

Commit

Permalink
Convert TSP start/end idx to nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Dec 17, 2024
1 parent 63f6591 commit 259096a
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions sgptools/utils/tsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,23 @@ def run_tsp(nodes,
max_dist=25,
depth=1,
resample=None,
start_idx=None,
end_idx=None,
start_nodes=None,
end_nodes=None,
time_limit=10):
"""Method to run TSP/VRP with arbitrary start and end nodes,
and without any distance constraint
Args:
nodes (ndarray): (# nodes, n_dim); Nodes to visit
nodes (ndarray): (# nodes, ndim); Nodes to visit
num_vehicles (int): Number of robots/vehicles
max_dist (float): Maximum distance allowed for each path when handling mutli-robot case
depth (int): Internal parameter used to track re-try recursion depth
resample (int): Each solution path will be resampled to have
`resample` number of points
start_idx (list): Optionl list of start node indices from which to start the solution path
end_idx (list): Optionl list of end node indices from which to start the solution path
start_nodes (ndarray): (# num_vehicles, ndim); Optionl array of start nodes from which
to start each vehicle's solution path
end_nodes (ndarray): (# num_vehicles, ndim); Optionl array of end nodes at which
to end each vehicle's solution path
time_limit (int): TSP runtime time limit in seconds
Returns:
Expand All @@ -48,28 +50,42 @@ def run_tsp(nodes,
if depth > 5:
print('Warning: Max depth reached')
return None, None


# Add the start and end nodes to the node list
if end_nodes is not None:
assert end_nodes.shape == (num_vehicles, nodes.shape[-1]), \
"Incorrect end_nodes shape, should be (num_vehicles, ndim)!"
nodes = np.concatenate([end_nodes, nodes])
if start_nodes is not None:
assert start_nodes.shape == (num_vehicles, nodes.shape[-1]), \
"Incorrect start_nodes shape, should be (num_vehicles, ndim)!"
nodes = np.concatenate([start_nodes, nodes])

# Add dummy 0 location to get arbitrary start and end node sols
if start_idx is None or end_idx is None:
if start_nodes is None or end_nodes is None:
distance_mat = np.zeros((len(nodes)+1, len(nodes)+1))
distance_mat[1:, 1:] = pairwise_distances(nodes, nodes)*1e4
trim_paths = True
trim_paths = True #shift to account for dummy node
else:
distance_mat = pairwise_distances(nodes, nodes)*1e4
trim_paths = False
distance_mat = distance_mat.astype(int)
max_dist = int(max_dist*1e4)

if start_idx is None:
# Get start and end node indices for ortools
if start_nodes is None:
start_idx = [0]*num_vehicles
elif trim_paths:
start_idx = [i+1 for i in start_idx]
num_start_nodes = 0
else:
start_idx = np.arange(num_vehicles)+int(trim_paths)
num_start_nodes = len(start_nodes)

if end_idx is None:
if end_nodes is None:
end_idx = [0]*num_vehicles
elif trim_paths:
end_idx = [i+1 for i in end_idx]
else:
end_idx = np.arange(num_vehicles)+num_start_nodes+int(trim_paths)

# used by ortools
def distance_callback(from_index, to_index):
from_node = manager.IndexToNode(from_index)
to_node = manager.IndexToNode(to_index)
Expand All @@ -78,8 +94,8 @@ def distance_callback(from_index, to_index):
# num_locations, num vehicles, start, end
manager = pywrapcp.RoutingIndexManager(len(distance_mat),
num_vehicles,
start_idx,
end_idx)
start_idx.tolist(),
end_idx.tolist())
routing = pywrapcp.RoutingModel(manager)
transit_callback_index = routing.RegisterTransitCallback(distance_callback)
routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
Expand Down Expand Up @@ -165,11 +181,11 @@ def resample_path(waypoints, num_inducing=10):
inducing points path with fixed number of waypoints
Args:
waypoints (ndarray): (num_waypoints, n_dim); waypoints of path from vrp solver
waypoints (ndarray): (num_waypoints, ndim); waypoints of path from vrp solver
num_inducing (int): Number of inducing points (waypoints) in the returned path
Returns:
points (ndarray): (num_inducing, n_dim); Resampled path
points (ndarray): (num_inducing, ndim); Resampled path
"""
ndim = np.shape(waypoints)[-1]
if not (ndim==2 or ndim==3):
Expand Down

0 comments on commit 259096a

Please # to comment.