Skip to content

Commit 52a41c7

Browse files
zietzmdhimmel
authored andcommitted
permute_pair_list: default to inplace=False
Closes #28 Adds test for permute_pair_list. Fixes unrelated filename typo.
1 parent 1a1da50 commit 52a41c7

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

hetio/permute.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def permute_graph(graph, multiplier=10, seed=0, metaedge_to_excluded=dict(), log
8080
return permuted_graph, all_stats
8181

8282

83-
def permute_pair_list(pair_list, directed=False, multiplier=10, excluded_pair_set=set(), seed=0, log=False):
83+
def permute_pair_list(pair_list, directed=False, multiplier=10, excluded_pair_set=set(),
84+
seed=0, log=False, inplace=False):
8485
"""
8586
Permute edges (of a single type) in a graph according to the XSwap function
8687
described in https://doi.org/f3mn58. This method selects two edges and
@@ -116,6 +117,8 @@ def permute_pair_list(pair_list, directed=False, multiplier=10, excluded_pair_se
116117
Seed to initialize Python random number generator.
117118
log : bool
118119
Whether to log diagnostic INFO via python's logging module.
120+
inplace : bool
121+
Whether to modify the edge list in place.
119122
120123
Returns
121124
-------
@@ -128,6 +131,9 @@ def permute_pair_list(pair_list, directed=False, multiplier=10, excluded_pair_se
128131
"""
129132
random.seed(seed)
130133

134+
if not inplace:
135+
pair_list = pair_list.copy()
136+
131137
pair_set = set(pair_list)
132138
assert len(pair_set) == len(pair_list)
133139

File renamed without changes.

test/permute_test.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
3+
import hetio.permute
4+
5+
6+
@pytest.mark.parametrize('edges,inplace', [
7+
([(0, 0), (1, 1), (1, 2), (2, 3)], True),
8+
([(0, 0), (1, 1), (1, 2), (2, 3)], False),
9+
])
10+
def test_permute_inplace(edges, inplace):
11+
old_edges = edges.copy()
12+
new_edges, stats = hetio.permute.permute_pair_list(edges, inplace=inplace)
13+
assert old_edges != new_edges
14+
15+
if inplace:
16+
assert edges == new_edges
17+
else:
18+
assert edges != new_edges

0 commit comments

Comments
 (0)