ppi.py
import os
# Allow the JAX backend to pre-allocate up to 99% of GPU memory (must be set before importing Keras)
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
import keras
import dgl
import gat.models
import src.optimizers
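# Options for resuming from a previously saved run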
load_last_weights = False
continue_training = False
initial_epoch = 0
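# Load the PPI (protein-protein interaction) dataset splits from DGL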
val_dataset = dgl.data.PPIDataset(mode='valid')
test_dataset = dgl.data.PPIDataset(mode='test')
train_dataset = dgl.data.PPIDataset(mode='train')
val_graphs = []
test_graphs = []
train_graphs = []
val_labels = []
test_labels = []
train_labels = []
mode_datasets = [val_dataset, test_dataset, train_dataset]
mode_graphs = [val_graphs, test_graphs, train_graphs]
mode_labels = [val_labels, test_labels, train_labels]
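# Convert each DGL graph into Keras tensors: (node features, edge list) plus its label matrix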
for i, dataset in enumerate(mode_datasets):
    for graph in dataset:
        # get edges
        edges = keras.ops.transpose(keras.ops.convert_to_tensor(graph.edges(), dtype='int32'))
        # get node features
        features = keras.ops.convert_to_tensor(graph.ndata['feat'])
        # get ground-truth labels
        mode_labels[i].append(keras.ops.convert_to_tensor(graph.ndata['label']))
        mode_graphs[i].append((features, edges))
# train and evaluate
# define hyper-parameters
output_dim = int(keras.ops.shape(train_labels[0])[-1])
num_epochs = 10000
#batch_size = 1 # number of graphs per batch
learning_rate = 0.001
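# Fix random seeds so initialisation and sampling are reproducible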
keras.utils.set_random_seed(1234)
random_gen = keras.random.SeedGenerator(1234)
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
optimizer = src.optimizers.Adan(learning_rate)
accuracy_fn = keras.metrics.BinaryAccuracy(name='acc')
f1_fn = keras.metrics.F1Score(average='micro', threshold=0.5, name='f1_score')
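# Stop training when the validation micro-F1 stops improving and restore the best weights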
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_f1_score',
    patience=200,
    mode='max',
    restore_best_weights=True
)
# build model
gat_model = gat.models.GraphAttentionNetworkInductive(output_dim, random_gen=random_gen)
# compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn, f1_fn])
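# Optionally restore weights saved by a previous run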
weightsfile = './weights/ppi.weights.h5'
os.makedirs(os.path.dirname(weightsfile), exist_ok=True)  # make sure the weights directory exists
if load_last_weights and os.path.isfile(weightsfile):
    gat_model(train_graphs[0])  # force model building
    gat_model.load_weights(weightsfile)
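# Build data generators for the validation, test, and train splits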
val_generator = gat.models.DataGenerator(val_graphs, val_labels)
test_generator = gat.models.DataGenerator(test_graphs, test_labels)
train_generator = gat.models.DataGenerator(train_graphs, train_labels)
if not load_last_weights or continue_training:
    gat_model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=num_epochs,
        callbacks=[early_stopping],
        verbose=2,
        initial_epoch=initial_epoch,
    )
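# Evaluate on the held-out test split and persist the trained weights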
test_loss, test_accuracy, test_f1 = gat_model.evaluate(test_generator, verbose=0)
gat_model.save_weights(weightsfile)
print('--'*38 + f'\nTest loss: {test_loss:.4f}, accuracy: {test_accuracy*100:.3f}%, F1 score: {test_f1*100:.3f}%')