Skip to content

Commit

Permalink
Bugfix: proper copy of returned attributes with graph-tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Silmathoron authored Oct 8, 2020
1 parent b68b8f8 commit 4b61aac
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion nngt/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def get_weights(self, edges=None):
'''
if self.is_weighted():
if edges is None:
return np.asarray(self._eattr["weight"])
return self._eattr["weight"]
else:
if len(edges) == 0:
return np.array([])
Expand Down
2 changes: 1 addition & 1 deletion nngt/lib/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _to_np_array(data, dtype):
arr[:] = data
return arr

return np.asarray(data, dtype=dtype)
return np.array(data, dtype=dtype)


def _to_list(string):
Expand Down
51 changes: 51 additions & 0 deletions testing/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,56 @@ def test_delays():
assert np.all(np.isclose(delays, dmin + slope*distances))


@pytest.mark.mpi_skip
def test_attributes_are_copied():
''' Check that the attributes returned are a copy '''
rng = np.random.default_rng()

nnodes = 100
nedges = 1000

wghts = rng.uniform(0, 5, nedges)

g = nngt.generation.erdos_renyi(nodes=nnodes, edges=nedges, weights=wghts)

# check weights
ww = g.get_weights()

assert np.all(np.isclose(wghts, ww))

rng.shuffle(ww)

assert np.all(np.isclose(wghts, g.get_weights()))
assert not np.all(np.isclose(ww, g.get_weights()))

# check edge attribute
g.new_edge_attribute("etest", "double", values=2*ww)

etest = g.edge_attributes["etest"]

assert np.all(np.isclose(etest, 2*ww))

rng.shuffle(etest)

assert np.all(np.isclose(2*ww, g.edge_attributes["etest"]))
assert not np.all(np.isclose(2*ww, etest))

# check node attribute
vv = rng.uniform(2, 3, nnodes)

g.new_node_attribute("ntest", "double", values=vv)

ntest = g.node_attributes["ntest"]

assert np.all(np.isclose(ntest, vv))

rng.shuffle(ntest)

assert np.all(np.isclose(vv, g.node_attributes["ntest"]))
assert not np.all(np.isclose(vv, ntest))



# ---------- #
# Test suite #
# ---------- #
Expand All @@ -382,3 +432,4 @@ def test_delays():
unittest.main()
test_str_attr()
test_delays()
test_attributes_are_copied()

0 comments on commit 4b61aac

Please # to comment.