diff --git a/nngt/core/graph.py b/nngt/core/graph.py index 70b44612..5d29bee6 100644 --- a/nngt/core/graph.py +++ b/nngt/core/graph.py @@ -868,39 +868,44 @@ def get_edges(self, attribute=None, value=None, source_node=None, self.edge_id((source_node, target_node)) edges = np.array([[source_node, target_node]]) else: - # we need to use the adjacency matrix, get its subparts, - # then use the list of nodes to get the original ids back - # to do that we first convert source/target_node to lists - # (note that this has no significant speed impact) - src, tgt = None, None - - if source_node is None: - src = np.array( - [i for i in range(self.node_nb())], dtype=int) - elif is_integer(source_node): - src = np.array([source_node], dtype=int) + if source_node is None or target_node is None: + # backend-specific implementation for source or target + edges = self._get_edges(source_node=source_node, + target_node=target_node) else: - src = np.sort(source_node) + # we need to use the adjacency matrix, get its subparts, + # then use the list of nodes to get the original ids back + # to do that we first convert source/target_node to lists + # (note that this has no significant speed impact) + src, tgt = None, None + + if source_node is None: + src = np.array( + [i for i in range(self.node_nb())], dtype=int) + elif is_integer(source_node): + src = np.array([source_node], dtype=int) + else: + src = np.sort(source_node) - if target_node is None: - tgt = np.array( - [i for i in range(self.node_nb())], dtype=int) - elif is_integer(target_node): - tgt = np.array([target_node], dtype=int) - else: - tgt = np.sort(target_node) + if target_node is None: + tgt = np.array( + [i for i in range(self.node_nb())], dtype=int) + elif is_integer(target_node): + tgt = np.array([target_node], dtype=int) + else: + tgt = np.sort(target_node) - mat = self.adjacency_matrix() + mat = self.adjacency_matrix() - nnz = mat[src].tocsc()[:, tgt].nonzero() + nnz = mat[src].tocsc()[:, tgt].nonzero() - edges = np.array([src[nnz[0]], tgt[nnz[1]]], dtype=int).T + edges = np.array([src[nnz[0]], tgt[nnz[1]]], dtype=int).T - # remove reciprocal if graph is undirected - if not self.is_directed(): - edges.sort() + # remove reciprocal if graph is undirected + if not self.is_directed(): + edges.sort() - edges = _unique_rows(edges) + edges = _unique_rows(edges) # check attributes if attribute is None: diff --git a/nngt/core/gt_graph.py b/nngt/core/gt_graph.py index 48bc0628..178f49ab 100755 --- a/nngt/core/gt_graph.py +++ b/nngt/core/gt_graph.py @@ -465,6 +465,38 @@ def edges_array(self): return edges[order, :2] + def _get_edges(self, source_node=None, target_node=None): + g = self._graph + + edges = set() + + if source_node is not None: + if is_integer(source_node): + return g.get_out_edges(source_node) + + for s in source_node: + if g.is_directed(): + edges.update((tuple(e) for e in g.get_out_edges(s))) + else: + for e in g.get_all_edges(s): + if tuple(e[::-1]) not in edges: + edges.add(tuple(e)) + + return list(edges) + + if is_integer(target_node): + return g.get_in_edges(target_node) + + for t in target_node: + if g.is_directed(): + edges.update((tuple(e) for e in g.get_in_edges(t))) + else: + for e in g.get_all_edges(t): + if tuple(e[::-1]) not in edges: + edges.add(tuple(e)) + + return list(edges) + def new_node(self, n=1, neuron_type=1, attributes=None, value_types=None, positions=None, groups=None): ''' diff --git a/nngt/core/ig_graph.py b/nngt/core/ig_graph.py index 72d2bbc6..b6e75fd7 100755 --- a/nngt/core/ig_graph.py +++ b/nngt/core/ig_graph.py @@ -319,6 +319,23 @@ def edges_array(self): g = self._graph return np.array([(e.source, e.target) for e in g.es], dtype=int) + def _get_edges(self, source_node=None, target_node=None): + g = self._graph + + edges = None + + if source_node is None: + if is_integer(target_node): + edges = g.es.select(_target_eq=target_node) + else: + edges = g.es.select(_target_in=target_node) + elif is_integer(source_node): + edges = g.es.select(_source_eq=source_node) + else: + edges = g.es.select(_source_in=source_node) + + return [e.tuple for e in edges] + def new_node(self, n=1, neuron_type=1, attributes=None, value_types=None, positions=None, groups=None): ''' diff --git a/nngt/core/nngt_graph.py b/nngt/core/nngt_graph.py index 9a45b9fb..3468c34a 100644 --- a/nngt/core/nngt_graph.py +++ b/nngt/core/nngt_graph.py @@ -404,6 +404,46 @@ def edges_array(self): ''' return np.asarray(list(self._graph._unique), dtype=int) + def _get_edges(self, source_node=None, target_node=None): + g = self._graph + + edges = None + + if source_node is not None: + source_node = \ + [source_node] if is_integer(source_node) else source_node + + if g.is_directed(): + edges = [e for e in g._unique if e[0] in source_node] + else: + edges = set() + + for e in g._unique: + if e[0] in source_node or e[1] in source_node: + if e[::-1] not in edges: + edges.add(e) + + edges = list(edges) + + return edges + + target_node = \ + [target_node] if is_integer(target_node) else target_node + + if g.is_directed(): + edges = [e for e in g._unique if e[1] in target_node] + else: + edges = set() + + for e in g._unique: + if e[0] in target_node or e[1] in target_node: + if e[::-1] not in edges: + edges.add(e) + + edges = list(edges) + + return edges + def is_connected(self): raise NotImplementedError("Not available with 'nngt' backend, please " "install a graph library (networkx, igraph, " diff --git a/nngt/core/nx_graph.py b/nngt/core/nx_graph.py index 5954684b..4eaa1b55 100755 --- a/nngt/core/nx_graph.py +++ b/nngt/core/nx_graph.py @@ -390,6 +390,24 @@ def edges_array(self): return edges + def _get_edges(self, source_node=None, target_node=None): + g = self._graph + + if source_node is not None: + source_node = \ + [source_node] if is_integer(source_node) else source_node + + return list( + g.out_edges(source_node) if g.is_directed() + else g.edges(source_node)) + + target_node = \ + [target_node] if is_integer(target_node) else target_node + + return list( + g.in_edges(target_node) if g.is_directed() + else g.edges(target_node)) + def new_node(self, n=1, neuron_type=1, attributes=None, value_types=None, positions=None, groups=None): ''' diff --git a/testing/test_basics.py b/testing/test_basics.py index 80464184..309f5313 100755 --- a/testing/test_basics.py +++ b/testing/test_basics.py @@ -588,10 +588,13 @@ def test_get_edges(): g.new_edges(edges) - assert np.array_equal(g.get_edges(source_node=[0, 1]), edges[:3]) - assert np.array_equal(g.get_edges(target_node=[0, 1]), edges[:2]) - assert np.array_equal(g.get_edges(source_node=[0, 2], target_node=[0, 1]), - [(0, 1)]) + def to_set(ee): + return {tuple(e) for e in ee} + + assert to_set(g.get_edges(source_node=[0, 1])) == to_set(edges[:3]) + assert to_set(g.get_edges(target_node=[0, 1])) == to_set(edges[:2]) + assert to_set(g.get_edges(source_node=[0, 2], + target_node=[0, 1])) == {(0, 1)} # undirected g = nngt.Graph(4, directed=False) @@ -602,14 +605,13 @@ def test_get_edges(): res = [(0, 1), (1, 2)] - assert np.array_equal(g.get_edges(source_node=[0, 1]), res) - assert np.array_equal(g.get_edges(target_node=[0, 1]), res) - assert np.array_equal(g.get_edges(source_node=[0, 2], target_node=[0, 1]), - res) + assert to_set(g.get_edges(source_node=[0, 1])) == to_set(res) + assert to_set(g.get_edges(target_node=[0, 1])) == to_set(res) + assert to_set(g.get_edges(source_node=[0, 2], + target_node=[0, 1])) == to_set(res) - assert np.array_equal(g.get_edges(source_node=0, target_node=1), [(0, 1)]) - assert np.array_equal(g.get_edges(source_node=0, target_node=[0, 1]), - [(0, 1)]) + assert to_set(g.get_edges(source_node=0, target_node=1)) == {(0, 1)} + assert to_set(g.get_edges(source_node=0, target_node=[0, 1])) == {(0, 1)} @pytest.mark.mpi_skip