model.py
from typing import Optional, Tuple

import torch
from torch import nn
from torch_geometric.data import Data


class Encoder(nn.Module):
    """Graph encoder producing the parameters of the approximate posterior q(z | x, A)."""

    def __init__(self, hidden_model, mean_model, std_model, use_edge_attr: bool = False):
        super().__init__()
        self.hidden_model = hidden_model
        self.mean_model = mean_model
        # std_model is treated as producing the log standard deviation (log sigma),
        # consistent with std = exp(logstd) in VGAE.reparametrize below.
        self.std_model = std_model
        self.use_edge_attr = use_edge_attr

    def encode(self,
               x: torch.Tensor,
               edge_index: torch.LongTensor,
               edge_attr: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.hidden_model(x, edge_index, edge_attr)
        mean = self.mean_model(hidden, edge_index, edge_attr)
        logstd = self.std_model(hidden, edge_index, edge_attr)
        return mean, logstd

    def forward(self, data: Data) -> Tuple[torch.Tensor, torch.Tensor]:
        x, edge_index = data.x, data.edge_index
        # Only forward edge attributes when the encoder was built to use them.
        edge_attr = data.edge_attr if self.use_edge_attr else None
        mu, logstd = self.encode(x, edge_index, edge_attr)
        return mu, logstd
class Decoder(nn.Module):
    """Inner-product decoder: reconstructs the adjacency matrix from the latent codes z."""

    def __init__(self, activation=torch.sigmoid, dropout: float = 0.1):
        super().__init__()
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        z = self.dropout(z)
        # Pairwise inner products yield one score per candidate edge.
        adj_reconstructed = torch.matmul(z, z.T)
        # The activation is applied only in training mode; in eval mode the raw
        # inner-product scores are returned.
        if self.training:
            adj_reconstructed = self.activation(adj_reconstructed)
        return adj_reconstructed
class VGAE(nn.Module):
    """Variational graph auto-encoder combining the encoder and decoder above."""

    def __init__(self, encoder: Encoder, decoder: Decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparametrize(self, mu: torch.Tensor, logstd: torch.Tensor) -> torch.Tensor:
        if self.training:
            # z = mu + sigma * eps, with sigma = exp(logstd) and eps ~ N(0, I).
            std = torch.exp(logstd)
            eps = torch.randn_like(std)
            return eps.mul(std) + mu
        # At evaluation time, use the posterior mean deterministically.
        return mu

    def forward(self, data: Data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mu, logstd = self.encoder(data)
        z = self.reparametrize(mu, logstd)
        adj = self.decoder(z)
        return adj, mu, logstd
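

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of how these classes could be wired together, assuming
# torch_geometric GCNConv layers for the hidden/mean/std heads and a toy graph.
# The layer sizes, the toy data, and the loss terms below are assumptions for
# illustration only, not prescribed by this module.
if __name__ == "__main__":
    from torch_geometric.nn import GCNConv

    in_channels, hidden_channels, latent_channels = 16, 32, 8

    encoder = Encoder(
        hidden_model=GCNConv(in_channels, hidden_channels),
        mean_model=GCNConv(hidden_channels, latent_channels),
        std_model=GCNConv(hidden_channels, latent_channels),  # interpreted as the log-std head
    )
    model = VGAE(encoder, Decoder())

    # Toy graph: 4 nodes, undirected edges stored as directed pairs.
    x = torch.randn(4, in_channels)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]], dtype=torch.long)
    data = Data(x=x, edge_index=edge_index)

    model.train()
    adj_rec, mu, logstd = model(data)

    # Reconstruction target: dense adjacency with 1s where edges exist.
    adj_true = torch.zeros(4, 4)
    adj_true[edge_index[0], edge_index[1]] = 1.0

    # The decoder applies sigmoid in training mode, so plain BCE is used here;
    # the KL term is the standard VGAE form written in terms of the log-std.
    recon_loss = nn.functional.binary_cross_entropy(adj_rec, adj_true)
    kl_loss = -0.5 * torch.mean(
        torch.sum(1 + 2 * logstd - mu.pow(2) - (2 * logstd).exp(), dim=1)
    )
    loss = recon_loss + kl_loss
    loss.backward()
    print(f"reconstruction loss: {recon_loss.item():.4f}, KL: {kl_loss.item():.4f}")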