From 259096a2cb5deeebad2da81dab42dadd4ab21d6c Mon Sep 17 00:00:00 2001 From: itskalvik Date: Tue, 17 Dec 2024 08:51:43 -0600 Subject: [PATCH] Convert TSP start/end idx to nodes --- sgptools/utils/tsp.py | 52 ++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/sgptools/utils/tsp.py b/sgptools/utils/tsp.py index 3a00f11..d0d9c45 100644 --- a/sgptools/utils/tsp.py +++ b/sgptools/utils/tsp.py @@ -24,21 +24,23 @@ def run_tsp(nodes, max_dist=25, depth=1, resample=None, - start_idx=None, - end_idx=None, + start_nodes=None, + end_nodes=None, time_limit=10): """Method to run TSP/VRP with arbitrary start and end nodes, and without any distance constraint Args: - nodes (ndarray): (# nodes, n_dim); Nodes to visit + nodes (ndarray): (# nodes, ndim); Nodes to visit num_vehicles (int): Number of robots/vehicles max_dist (float): Maximum distance allowed for each path when handling mutli-robot case depth (int): Internal parameter used to track re-try recursion depth resample (int): Each solution path will be resampled to have `resample` number of points - start_idx (list): Optionl list of start node indices from which to start the solution path - end_idx (list): Optionl list of end node indices from which to start the solution path + start_nodes (ndarray): (# num_vehicles, ndim); Optionl array of start nodes from which + to start each vehicle's solution path + end_nodes (ndarray): (# num_vehicles, ndim); Optionl array of end nodes at which + to end each vehicle's solution path time_limit (int): TSP runtime time limit in seconds Returns: @@ -48,28 +50,42 @@ def run_tsp(nodes, if depth > 5: print('Warning: Max depth reached') return None, None - + + # Add the start and end nodes to the node list + if end_nodes is not None: + assert end_nodes.shape == (num_vehicles, nodes.shape[-1]), \ + "Incorrect end_nodes shape, should be (num_vehicles, ndim)!" + nodes = np.concatenate([end_nodes, nodes]) + if start_nodes is not None: + assert start_nodes.shape == (num_vehicles, nodes.shape[-1]), \ + "Incorrect start_nodes shape, should be (num_vehicles, ndim)!" + nodes = np.concatenate([start_nodes, nodes]) + # Add dummy 0 location to get arbitrary start and end node sols - if start_idx is None or end_idx is None: + if start_nodes is None or end_nodes is None: distance_mat = np.zeros((len(nodes)+1, len(nodes)+1)) distance_mat[1:, 1:] = pairwise_distances(nodes, nodes)*1e4 - trim_paths = True + trim_paths = True #shift to account for dummy node else: distance_mat = pairwise_distances(nodes, nodes)*1e4 trim_paths = False distance_mat = distance_mat.astype(int) max_dist = int(max_dist*1e4) - if start_idx is None: + # Get start and end node indices for ortools + if start_nodes is None: start_idx = [0]*num_vehicles - elif trim_paths: - start_idx = [i+1 for i in start_idx] + num_start_nodes = 0 + else: + start_idx = np.arange(num_vehicles)+int(trim_paths) + num_start_nodes = len(start_nodes) - if end_idx is None: + if end_nodes is None: end_idx = [0]*num_vehicles - elif trim_paths: - end_idx = [i+1 for i in end_idx] + else: + end_idx = np.arange(num_vehicles)+num_start_nodes+int(trim_paths) + # used by ortools def distance_callback(from_index, to_index): from_node = manager.IndexToNode(from_index) to_node = manager.IndexToNode(to_index) @@ -78,8 +94,8 @@ def distance_callback(from_index, to_index): # num_locations, num vehicles, start, end manager = pywrapcp.RoutingIndexManager(len(distance_mat), num_vehicles, - start_idx, - end_idx) + start_idx.tolist(), + end_idx.tolist()) routing = pywrapcp.RoutingModel(manager) transit_callback_index = routing.RegisterTransitCallback(distance_callback) routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index) @@ -165,11 +181,11 @@ def resample_path(waypoints, num_inducing=10): inducing points path with fixed number of waypoints Args: - waypoints (ndarray): (num_waypoints, n_dim); waypoints of path from vrp solver + waypoints (ndarray): (num_waypoints, ndim); waypoints of path from vrp solver num_inducing (int): Number of inducing points (waypoints) in the returned path Returns: - points (ndarray): (num_inducing, n_dim); Resampled path + points (ndarray): (num_inducing, ndim); Resampled path """ ndim = np.shape(waypoints)[-1] if not (ndim==2 or ndim==3):