from collections import namedtuple
from typing import Any, Dict, Iterator, List, Iterable
import tensorflow as tf
import numpy as np
from dpu_utils.utils import RichPath
from .sparse_graph_task import Sparse_Graph_Task, DataFold, MinibatchData
from utils import micro_f1


GraphSample = namedtuple('GraphSample', ['adjacency_lists',
'type_to_node_to_num_incoming_edges',
'node_features',
'node_labels',
])
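# Illustrative example of a tiny GraphSample with 3 nodes (values are made up; the
# shapes match what __load_data produces below; PPI node features are 50-dimensional
# and there are 121 labels):
#   GraphSample(adjacency_lists=[np.array([[0, 1], [1, 2]]), ...],   # one [E, 2] array of (src, tgt) pairs per edge type
#               type_to_node_to_num_incoming_edges=np.zeros([3, 3]), # [num_edge_types, num_nodes]
#               node_features=np.zeros([3, 50]),                     # [num_nodes, feature_dim]
#               node_labels=np.zeros([3, 121]))                      # [num_nodes, num_labels]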


class PPI_Task(Sparse_Graph_Task):
@classmethod
def default_params(cls):
params = super().default_params()
params.update({
'add_self_loop_edges': True,
'tie_fwd_bkwd_edges': False,
'out_layer_dropout_keep_prob': 1.0,
})
return params
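    # A hypothetical override example; any of the defaults can be replaced before
    # construction:
    #   params = PPI_Task.default_params()
    #   params['out_layer_dropout_keep_prob'] = 0.9
    #   task = PPI_Task(params)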

    @staticmethod
def name() -> str:
return "PPI"

    @staticmethod
def default_data_path() -> str:
return "data/ppi"

    def __init__(self, params: Dict[str, Any]):
super().__init__(params)
# Things that will be filled once we load data:
self.__num_edge_types = 0
self.__initial_node_feature_size = 0
self.__num_labels = 0

    def get_metadata(self) -> Dict[str, Any]:
metadata = super().get_metadata()
metadata['num_edge_types'] = self.__num_edge_types
metadata['initial_node_feature_size'] = self.__initial_node_feature_size
metadata['num_labels'] = self.__num_labels
return metadata

    def restore_from_metadata(self, metadata: Dict[str, Any]) -> None:
super().restore_from_metadata(metadata)
self.__num_edge_types = metadata['num_edge_types']
self.__initial_node_feature_size = metadata['initial_node_feature_size']
self.__num_labels = metadata['num_labels']

    @property
def num_edge_types(self) -> int:
return self.__num_edge_types

    @property
def initial_node_feature_size(self) -> int:
return self.__initial_node_feature_size

    # -------------------- Data Loading --------------------
def load_data(self, path: RichPath) -> None:
        # Data comes in the format downloaded from https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ppi.zip
self._loaded_data[DataFold.TRAIN] = self.__load_data(path, DataFold.TRAIN)
self._loaded_data[DataFold.VALIDATION] = self.__load_data(path, DataFold.VALIDATION)

    def load_eval_data_from_path(self, path: RichPath) -> Iterable[Any]:
return self.__load_data(path, DataFold.TEST)

    def __load_data(self, data_dir: RichPath, data_fold: DataFold) -> List[GraphSample]:
if data_fold == DataFold.TRAIN:
data_name = "train"
elif data_fold == DataFold.VALIDATION:
data_name = "valid"
elif data_fold == DataFold.TEST:
data_name = "test"
else:
raise ValueError("Unknown data fold '%s'" % str(data_fold))
print(" Loading PPI %s data from %s." % (data_name, data_dir))
graph_json_data = data_dir.join("%s_graph.json" % data_name).read_by_file_suffix()
node_to_features = data_dir.join("%s_feats.npy" % data_name).read_by_file_suffix()
node_to_labels = data_dir.join("%s_labels.npy" % data_name).read_by_file_suffix()
node_to_graph_id = data_dir.join("%s_graph_id.npy" % data_name).read_by_file_suffix()
self.__initial_node_feature_size = node_to_features.shape[-1]
self.__num_labels = node_to_labels.shape[-1]
# We read in all the data in two steps:
# (1) Read features, labels and insert self-loop edges (edge type 0).
# Implicitly, this gives us the number of nodes per graph.
# (2) Read all edges, and shift them so that each graph starts with node 0.
fwd_edge_type = 0
self.__num_edge_types = 1
if self.params['add_self_loop_edges']:
self_loop_edge_type = self.__num_edge_types
self.__num_edge_types += 1
if not self.params['tie_fwd_bkwd_edges']:
bkwd_edge_type = self.__num_edge_types
self.__num_edge_types += 1
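        # With the default parameters this yields edge types 0 = forward,
        # 1 = self-loop and 2 = backward.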
graph_id_to_graph_data = {} # type: Dict[int, GraphSample]
graph_id_to_node_offset = {}
num_total_nodes = node_to_features.shape[0]
for node_id in range(num_total_nodes):
graph_id = node_to_graph_id[node_id]
            # If we are entering a new graph, record the ID of its first node, so that we can renumber its nodes to start at 0
if graph_id not in graph_id_to_graph_data:
graph_id_to_graph_data[graph_id] = \
GraphSample(adjacency_lists=[[] for _ in range(self.__num_edge_types)],
type_to_node_to_num_incoming_edges=[[] for _ in range(self.__num_edge_types)],
node_features=[],
node_labels=[])
graph_id_to_node_offset[graph_id] = node_id
cur_graph_data = graph_id_to_graph_data[graph_id]
cur_graph_data.node_features.append(node_to_features[node_id])
cur_graph_data.node_labels.append(node_to_labels[node_id])
shifted_node_id = node_id - graph_id_to_node_offset[graph_id]
if self.params['add_self_loop_edges']:
cur_graph_data.adjacency_lists[self_loop_edge_type].append((shifted_node_id, shifted_node_id))
cur_graph_data.type_to_node_to_num_incoming_edges[self_loop_edge_type].append(1)
# Prepare reading of the edges by setting counters to 0:
for graph_data in graph_id_to_graph_data.values():
num_graph_nodes = len(graph_data.node_features)
graph_data.type_to_node_to_num_incoming_edges[fwd_edge_type] = np.zeros([num_graph_nodes], np.int32)
if not self.params['tie_fwd_bkwd_edges']:
graph_data.type_to_node_to_num_incoming_edges[bkwd_edge_type] = np.zeros([num_graph_nodes], np.int32)
for edge_info in graph_json_data['links']:
src_node, tgt_node = edge_info['source'], edge_info['target']
# First, shift node IDs so that each graph starts at node 0:
graph_id = node_to_graph_id[src_node]
graph_node_offset = graph_id_to_node_offset[graph_id]
src_node, tgt_node = src_node - graph_node_offset, tgt_node - graph_node_offset
cur_graph_data = graph_id_to_graph_data[graph_id]
cur_graph_data.adjacency_lists[fwd_edge_type].append((src_node, tgt_node))
cur_graph_data.type_to_node_to_num_incoming_edges[fwd_edge_type][tgt_node] += 1
if not self.params['tie_fwd_bkwd_edges']:
cur_graph_data.adjacency_lists[bkwd_edge_type].append((tgt_node, src_node))
cur_graph_data.type_to_node_to_num_incoming_edges[bkwd_edge_type][src_node] += 1
final_graphs = []
for graph_data in graph_id_to_graph_data.values():
            # Convert to numpy arrays:
adj_lists = []
for edge_type_idx in range(self.__num_edge_types):
adj_lists.append(np.array(graph_data.adjacency_lists[edge_type_idx]))
final_graphs.append(
GraphSample(adjacency_lists=adj_lists,
type_to_node_to_num_incoming_edges=np.array(graph_data.type_to_node_to_num_incoming_edges),
node_features=np.array(graph_data.node_features),
node_labels=np.array(graph_data.node_labels)))
return final_graphs

    # -------------------- Model Construction --------------------
def make_task_input_model(self,
placeholders: Dict[str, tf.Tensor],
model_ops: Dict[str, tf.Tensor],
) -> None:
super().make_task_input_model(placeholders=placeholders, model_ops=model_ops)
placeholders['graph_to_nodes'] = \
tf.placeholder(dtype=tf.int32, shape=[None, None], name='graph_to_nodes') # (G, V)
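        # Each row of graph_to_nodes lists the batch-global node IDs of one graph,
        # padded to the size of the largest graph with -1 (see make_minibatch_iterator).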

    def make_task_output_model(self,
placeholders: Dict[str, tf.Tensor],
model_ops: Dict[str, tf.Tensor],
) -> None:
placeholders['graph_nodes_list'] = \
tf.placeholder(dtype=tf.int32, shape=[None], name='graph_nodes_list')
placeholders['target_labels'] = \
tf.placeholder(dtype=tf.float32, shape=[None, self.__num_labels], name='target_labels')
placeholders['out_layer_dropout_keep_prob'] = \
tf.placeholder(dtype=tf.float32, shape=[], name='out_layer_dropout_keep_prob')

        # Apply dropout to the final node representations before the output layer:
        final_node_representations = \
            tf.nn.dropout(model_ops['final_node_representations'],
                          keep_prob=placeholders['out_layer_dropout_keep_prob'])
        per_node_logits = \
            tf.keras.layers.Dense(units=self.__num_labels,
                                  use_bias=True,
                                  )(final_node_representations)
losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=per_node_logits,
labels=placeholders['target_labels'])
total_loss = tf.reduce_sum(losses)
# Compute loss as average per node (to account for changing number of nodes per batch):
num_nodes_in_batch = tf.shape(placeholders['target_labels'])[0]
f1_score = micro_f1(per_node_logits, placeholders['target_labels'])
tf.summary.scalar("Micro F1", f1_score)
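        # Micro-F1 pools true/false positives over all (node, label) pairs before
        # computing precision and recall. A rough sketch of the metric (assuming
        # utils.micro_f1 thresholds sigmoid(logits) at 0.5, as is conventional):
        #   pred = sigmoid(logits) >= 0.5
        #   tp, fp, fn = sum(pred & label), sum(pred & ~label), sum(~pred & label)
        #   precision, recall = tp / (tp + fp), tp / (tp + fn)
        #   f1 = 2 * precision * recall / (precision + recall)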
model_ops['task_metrics'] = {
'loss': total_loss / tf.cast(num_nodes_in_batch, tf.float32),
'total_loss': total_loss,
'f1_score': f1_score,
}

    # -------------------- Minibatching and training loop --------------------
def make_minibatch_iterator(self,
                                data: List[Any],
data_fold: DataFold,
model_placeholders: Dict[str, tf.Tensor],
max_nodes_per_batch: int) \
-> Iterator[MinibatchData]:
if data_fold == DataFold.TRAIN:
np.random.shuffle(data)
out_layer_dropout_keep_prob = self.params['out_layer_dropout_keep_prob']
else:
out_layer_dropout_keep_prob = 1.0
# Pack until we cannot fit more graphs in the batch
num_graphs = 0
while num_graphs < len(data):
num_graphs_in_batch = 0
batch_node_features = [] # type: List[np.ndarray]
batch_node_labels = []
batch_adjacency_lists = [[] for _ in range(self.num_edge_types)] # type: List[List[np.ndarray]]
batch_type_to_num_incoming_edges = []
batch_graph_nodes_list = []
graph_to_nodes = []
node_offset = 0
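            # node_offset turns per-graph node IDs into batch-global IDs: if the first
            # graph has 3 nodes and the second has 2, the merged IDs are 0..2 and 3..4.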
            # Admit at least one graph per batch, so that a single graph larger than
            # max_nodes_per_batch cannot stall the iterator:
            while num_graphs < len(data) \
                    and (num_graphs_in_batch == 0
                         or node_offset + len(data[num_graphs].node_features) < max_nodes_per_batch):
cur_graph = data[num_graphs]
num_nodes_in_graph = len(data[num_graphs].node_features)
batch_node_features.extend(cur_graph.node_features)
batch_graph_nodes_list.append(np.full(shape=[num_nodes_in_graph],
fill_value=num_graphs_in_batch,
dtype=np.int32))
for i in range(self.num_edge_types):
batch_adjacency_lists[i].append(cur_graph.adjacency_lists[i] + node_offset)
batch_type_to_num_incoming_edges.append(cur_graph.type_to_node_to_num_incoming_edges)
batch_node_labels.append(cur_graph.node_labels)
graph_to_nodes.append([i + node_offset for i in range(num_nodes_in_graph)])
num_graphs += 1
num_graphs_in_batch += 1
node_offset += num_nodes_in_graph
batch_feed_dict = {
model_placeholders['initial_node_features']: np.array(batch_node_features),
model_placeholders['type_to_num_incoming_edges']: np.concatenate(batch_type_to_num_incoming_edges, axis=1),
model_placeholders['graph_nodes_list']: np.concatenate(batch_graph_nodes_list),
model_placeholders['target_labels']: np.concatenate(batch_node_labels, axis=0),
model_placeholders['out_layer_dropout_keep_prob']: out_layer_dropout_keep_prob,
model_placeholders['graph_to_nodes']: Sparse_Graph_Task.pad_lists(graph_to_nodes, value=-1),
}
# Merge adjacency lists:
num_edges = 0
for i in range(self.num_edge_types):
if len(batch_adjacency_lists[i]) > 0:
adj_list = np.concatenate(batch_adjacency_lists[i])
else:
adj_list = np.zeros((0, 2), dtype=np.int32)
num_edges += adj_list.shape[0]
batch_feed_dict[model_placeholders['adjacency_lists'][i]] = adj_list
yield MinibatchData(feed_dict=batch_feed_dict,
num_graphs=num_graphs_in_batch,
num_nodes=node_offset,
num_edges=num_edges)

    def early_stopping_metric(self, task_metric_results: List[Dict[str, np.ndarray]], num_graphs: int) -> float:
# Early stopping based on average loss:
return np.sum([m['total_loss'] for m in task_metric_results]) / num_graphs

    def pretty_print_epoch_task_metrics(self, task_metric_results: List[Dict[str, np.ndarray]], num_graphs: int) -> str:
avg_microf1 = np.average([m['f1_score'] for m in task_metric_results])
return "Avg MicroF1: %.3f" % (avg_microf1,)