-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate.py
147 lines (117 loc) · 4.8 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import pandas as pd
import pickle as pkl
from graph_tool.all import load_graph
from glob import glob
from tqdm import tqdm
from utils import edges2graph
from infer_time import fill_missing_time
from scipy.stats import kendalltau
from feasibility import is_arborescence
def edge_order_accuracy(pred_edges, infection_times):
    """Return the fraction of edges whose direction agrees with infection order.

    An edge ``(u, v)`` counts as correctly ordered when
    ``infection_times[u] <= infection_times[v]``.

    Args:
        pred_edges: sized collection of ``(u, v)`` pairs.
        infection_times: object indexable by the edge endpoints
            (e.g. a sequence, array or dict of infection times).

    Returns:
        The fraction of correctly ordered edges, or ``0.0`` when
        ``pred_edges`` is empty.
    """
    n_edges = len(pred_edges)
    # Compare the length explicitly rather than relying on truthiness, so
    # array-like edge containers work as well as plain lists.
    if n_edges == 0:
        # Nothing to score: report 0.0 instead of dividing by zero.
        return 0.0
    n_correct_edges = sum(1
                          for u, v in pred_edges
                          if infection_times[u] <= infection_times[v])
    return n_correct_edges / n_edges
# @profile
def evaluate_performance(g, root, source, pred_edges, obs_nodes, infection_times,
                         true_edges):
    """Score a predicted cascade tree against the true cascade.

    Note that the incoming ``root`` value is discarded: the root is re-derived
    from the predicted tree as the first vertex with no in-edges and at least
    one out-edge.  ``source`` is accepted but never read.

    Args:
        g: graph handed to ``edges2graph`` and ``fill_missing_time``.
        root: ignored (see above).
        source: ignored (see above).
        pred_edges: predicted edges, each indexable as ``(u, v)``.
        obs_nodes: observed nodes, forwarded to ``fill_missing_time``.
        infection_times: true infection time per node; it is indexed with a
            list of nodes, so presumably a numpy array -- TODO confirm.
        true_edges: edges of the true cascade.

    Returns:
        Tuple ``(n_prec, n_rec, obj, e_prec, e_rec, rank_corr, order_accuracy)``:
        node precision/recall, number of predicted edges, edge
        precision/recall, Kendall tau between inferred and true infection
        times over the nodes common to both trees, and the edge-order accuracy
        over predicted edges whose endpoints are both common nodes
        (``0.0`` when there are no such edges).

    Raises:
        AssertionError: if the predicted tree fails ``is_arborescence``.
    """
    # --- node-level precision / recall ---
    actual_nodes = {node for edge in true_edges for node in edge}
    guessed_nodes = {node for edge in pred_edges for node in edge}
    shared_nodes = actual_nodes & guessed_nodes

    n_prec = len(shared_nodes) / len(guessed_nodes)
    n_rec = len(shared_nodes) / len(actual_nodes)

    # objective value: size of the predicted tree
    obj = len(pred_edges)

    # --- rebuild the predicted tree and infer the missing infection times ---
    pred_tree = edges2graph(g, pred_edges)
    tree_root = next(
        vertex
        for vertex in pred_tree.vertices()
        if vertex.in_degree() == 0 and vertex.out_degree() > 0
    )
    assert is_arborescence(pred_tree)

    pred_times = fill_missing_time(
        g, pred_tree, tree_root, obs_nodes, infection_times, debug=False)

    # rank correlation only over predicted nodes that are actual infections
    shared_list = list(shared_nodes)
    tau_result = kendalltau(pred_times[shared_list], infection_times[shared_list])
    rank_corr = tau_result[0]

    # --- edge-level precision / recall ---
    shared_edges = set(pred_edges).intersection(true_edges)
    e_prec = len(shared_edges) / len(pred_edges)
    e_rec = len(shared_edges) / len(true_edges)

    # --- order accuracy, restricted to edges between common nodes ---
    scorable_edges = []
    for edge in pred_edges:
        if edge[0] in shared_nodes and edge[1] in shared_nodes:
            scorable_edges.append(edge)
    order_accuracy = (edge_order_accuracy(scorable_edges, infection_times)
                      if scorable_edges else 0.0)

    return (n_prec, n_rec, obj, e_prec, e_rec, rank_corr, order_accuracy)
def evaluate_from_result_dir(g, result_dir, qs):
    """Yield aggregated evaluation scores for every ``q`` under ``result_dir``.

    For each ``q`` in ``qs``, every pickle matching ``<result_dir>/<q>/*.pkl``
    is loaded and scored with ``evaluate_performance``.

    Args:
        g: graph forwarded to ``evaluate_performance``.
        result_dir: directory holding one sub-directory of ``*.pkl`` result
            files per ``q``.
        qs: iterable of ``q`` values; each is formatted into the paths.

    Yields:
        ``(path, df)`` where ``path`` is ``<result_dir>/<q>.pkl`` and ``df``
        holds one row of scores per result file; or ``None`` when no result
        file was found for that ``q`` (an existing file at ``path`` is removed
        in that case).

    Raises:
        AssertionError: re-raised from ``evaluate_performance`` after printing
            the path of the offending result file.
    """
    columns = ['n.prec', 'n.rec',
               'obj',
               'e.prec', 'e.rec',
               'rank-corr',
               'order accuracy'
               ]
    for q in tqdm(qs):
        rows = []
        for p in glob(result_dir + "/{}/*.pkl".format(q)):
            # TODO: add root
            # NOTE(review): pickle can execute arbitrary code while loading;
            # only point this at result files you produced yourself.
            # Use a context manager so the file handle is closed promptly
            # instead of leaking one handle per result file.
            with open(p, 'rb') as f:
                infection_times, source, obs_nodes, true_edges, pred_edges = pkl.load(f)
            root = None
            try:
                scores = evaluate_performance(g, root, source, pred_edges, obs_nodes,
                                              infection_times, true_edges)
            except AssertionError:
                # report which result file produced an invalid tree, then fail
                import sys
                print(p)
                print(sys.exc_info()[0])
                raise
            rows.append(scores)

        path = result_dir + "/{}.pkl".format(q)
        if rows:
            df = pd.DataFrame(rows, columns=columns)
            yield (path, df)
        else:
            # no results for this q: drop any stale summary file
            if os.path.exists(path):
                os.remove(path)
            yield None
if __name__ == '__main__':
    # Command-line entry point: evaluate every result file under
    # <output_dir>/<gtype>/<model>/<method>/qs/<q>/ and write the
    # `describe()` summary of the scores to
    # <output_dir>/<gtype>/<model>/<method>/qs/<q>.pkl for each q.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gtype', required=True)
    parser.add_argument('-l', '--model', required=True)
    parser.add_argument('-m', '--method', required=True)
    # Required: without it args.qs would be None and that None would be
    # handed to the evaluation loop; fail early with a usage error instead.
    parser.add_argument('-q', '--qs', type=float, nargs="+", required=True)
    parser.add_argument('-o', '--output_dir', default='outputs/paper_experiment')

    args = parser.parse_args()
    gtype = args.gtype
    qs = args.qs
    method = args.method
    model = args.model
    output_dir = args.output_dir

    print("""graph: {}
model: {}
qs: {}
method: {}""".format(gtype, model, qs, method))

    result_dir = "{output_dir}/{gtype}/{model}/{method}/qs".format(
        output_dir=output_dir,
        gtype=gtype,
        model=model,
        method=method)

    g = load_graph('data/{}/graph.gt'.format(gtype))

    for r in evaluate_from_result_dir(g, result_dir, qs):
        if r:
            path, df = r
            print('writing to {}'.format(path))
            # only the summary statistics are persisted, not the raw rows
            df.describe().to_pickle(path)
        else:
            print('not result.')