Skip to content

Commit

Permalink
Speedup get_edges (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
Silmathoron authored Oct 15, 2020
1 parent 294bb03 commit 87387ca
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 37 deletions.
57 changes: 31 additions & 26 deletions nngt/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,39 +868,44 @@ def get_edges(self, attribute=None, value=None, source_node=None,
self.edge_id((source_node, target_node))
edges = np.array([[source_node, target_node]])
else:
# we need to use the adjacency matrix, get its subparts,
# then use the list of nodes to get the original ids back
# to do that we first convert source/target_node to lists
# (note that this has no significant speed impact)
src, tgt = None, None

if source_node is None:
src = np.array(
[i for i in range(self.node_nb())], dtype=int)
elif is_integer(source_node):
src = np.array([source_node], dtype=int)
if source_node is None or target_node is None:
# backend-specific implementation for source or target
edges = self._get_edges(source_node=source_node,
target_node=target_node)
else:
src = np.sort(source_node)
# we need to use the adjacency matrix, get its subparts,
# then use the list of nodes to get the original ids back
# to do that we first convert source/target_node to lists
# (note that this has no significant speed impact)
src, tgt = None, None

if source_node is None:
src = np.array(
[i for i in range(self.node_nb())], dtype=int)
elif is_integer(source_node):
src = np.array([source_node], dtype=int)
else:
src = np.sort(source_node)

if target_node is None:
tgt = np.array(
[i for i in range(self.node_nb())], dtype=int)
elif is_integer(target_node):
tgt = np.array([target_node], dtype=int)
else:
tgt = np.sort(target_node)
if target_node is None:
tgt = np.array(
[i for i in range(self.node_nb())], dtype=int)
elif is_integer(target_node):
tgt = np.array([target_node], dtype=int)
else:
tgt = np.sort(target_node)

mat = self.adjacency_matrix()
mat = self.adjacency_matrix()

nnz = mat[src].tocsc()[:, tgt].nonzero()
nnz = mat[src].tocsc()[:, tgt].nonzero()

edges = np.array([src[nnz[0]], tgt[nnz[1]]], dtype=int).T
edges = np.array([src[nnz[0]], tgt[nnz[1]]], dtype=int).T

# remove reciprocal if graph is undirected
if not self.is_directed():
edges.sort()
# remove reciprocal if graph is undirected
if not self.is_directed():
edges.sort()

edges = _unique_rows(edges)
edges = _unique_rows(edges)

# check attributes
if attribute is None:
Expand Down
32 changes: 32 additions & 0 deletions nngt/core/gt_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,38 @@ def edges_array(self):

return edges[order, :2]

def _get_edges(self, source_node=None, target_node=None):
g = self._graph

edges = set()

if source_node is not None:
if is_integer(source_node):
return g.get_out_edges(source_node)

for s in source_node:
if g.is_directed():
edges.update((tuple(e) for e in g.get_out_edges(s)))
else:
for e in g.get_all_edges(s):
if tuple(e[::-1]) not in edges:
edges.add(tuple(e))

return list(edges)

if is_integer(target_node):
return g.get_in_edges(target_node)

for t in target_node:
if g.is_directed():
edges.update((tuple(e) for e in g.get_in_edges(t)))
else:
for e in g.get_all_edges(t):
if tuple(e[::-1]) not in edges:
edges.add(tuple(e))

return list(edges)

def new_node(self, n=1, neuron_type=1, attributes=None, value_types=None,
positions=None, groups=None):
'''
Expand Down
17 changes: 17 additions & 0 deletions nngt/core/ig_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,23 @@ def edges_array(self):
g = self._graph
return np.array([(e.source, e.target) for e in g.es], dtype=int)

def _get_edges(self, source_node=None, target_node=None):
g = self._graph

edges = None

if source_node is None:
if is_integer(target_node):
edges = g.es.select(_target_eq=target_node)
else:
edges = g.es.select(_target_in=target_node)
elif is_integer(source_node):
edges = g.es.select(_source_eq=source_node)
else:
edges = g.es.select(_source_in=source_node)

return [e.tuple for e in edges]

def new_node(self, n=1, neuron_type=1, attributes=None, value_types=None,
positions=None, groups=None):
'''
Expand Down
40 changes: 40 additions & 0 deletions nngt/core/nngt_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,46 @@ def edges_array(self):
'''
return np.asarray(list(self._graph._unique), dtype=int)

def _get_edges(self, source_node=None, target_node=None):
g = self._graph

edges = None

if source_node is not None:
source_node = \
[source_node] if is_integer(source_node) else source_node

if g.is_directed():
edges = [e for e in g._unique if e[0] in source_node]
else:
edges = set()

for e in g._unique:
if e[0] in source_node or e[1] in source_node:
if e[::-1] not in edges:
edges.add(e)

edges = list(edges)

return edges

target_node = \
[target_node] if is_integer(target_node) else target_node

if g.is_directed():
edges = [e for e in g._unique if e[1] in target_node]
else:
edges = set()

for e in g._unique:
if e[0] in target_node or e[1] in target_node:
if e[::-1] not in edges:
edges.add(e)

edges = list(edges)

return edges

def is_connected(self):
raise NotImplementedError("Not available with 'nngt' backend, please "
"install a graph library (networkx, igraph, "
Expand Down
18 changes: 18 additions & 0 deletions nngt/core/nx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,24 @@ def edges_array(self):

return edges

def _get_edges(self, source_node=None, target_node=None):
g = self._graph

if source_node is not None:
source_node = \
[source_node] if is_integer(source_node) else source_node

return list(
g.out_edges(source_node) if g.is_directed()
else g.edges(source_node))

target_node = \
[target_node] if is_integer(target_node) else target_node

return list(
g.in_edges(target_node) if g.is_directed()
else g.edges(target_node))

def new_node(self, n=1, neuron_type=1, attributes=None, value_types=None,
positions=None, groups=None):
'''
Expand Down
24 changes: 13 additions & 11 deletions testing/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,13 @@ def test_get_edges():

g.new_edges(edges)

assert np.array_equal(g.get_edges(source_node=[0, 1]), edges[:3])
assert np.array_equal(g.get_edges(target_node=[0, 1]), edges[:2])
assert np.array_equal(g.get_edges(source_node=[0, 2], target_node=[0, 1]),
[(0, 1)])
def to_set(ee):
return {tuple(e) for e in ee}

assert to_set(g.get_edges(source_node=[0, 1])) == to_set(edges[:3])
assert to_set(g.get_edges(target_node=[0, 1])) == to_set(edges[:2])
assert to_set(g.get_edges(source_node=[0, 2],
target_node=[0, 1])) == {(0, 1)}

# undirected
g = nngt.Graph(4, directed=False)
Expand All @@ -602,14 +605,13 @@ def test_get_edges():

res = [(0, 1), (1, 2)]

assert np.array_equal(g.get_edges(source_node=[0, 1]), res)
assert np.array_equal(g.get_edges(target_node=[0, 1]), res)
assert np.array_equal(g.get_edges(source_node=[0, 2], target_node=[0, 1]),
res)
assert to_set(g.get_edges(source_node=[0, 1])) == to_set(res)
assert to_set(g.get_edges(target_node=[0, 1])) == to_set(res)
assert to_set(g.get_edges(source_node=[0, 2],
target_node=[0, 1])) == to_set(res)

assert np.array_equal(g.get_edges(source_node=0, target_node=1), [(0, 1)])
assert np.array_equal(g.get_edges(source_node=0, target_node=[0, 1]),
[(0, 1)])
assert to_set(g.get_edges(source_node=0, target_node=1)) == {(0, 1)}
assert to_set(g.get_edges(source_node=0, target_node=[0, 1])) == {(0, 1)}


@pytest.mark.mpi_skip
Expand Down

0 comments on commit 87387ca

Please # to comment.