-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path model.py
65 lines (48 loc) · 1.63 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from torch import nn
from torch.utils.data import Dataset
import h5py
class EntDataset(Dataset):
    """HDF5-backed dataset of paired feature vectors.

    Each row of the file's 'dataset' array holds an input vector in columns
    [0, feat_dim) followed by a target vector in columns [feat_dim, 2*feat_dim).
    """

    def __init__(self, data_path):
        self.data_path = data_path
        self.feat_dim = 228  # width of each half of a row (input / target)
        # Open briefly just to record the number of rows; a context manager
        # guarantees the handle is closed even if the read raises.
        with h5py.File(self.data_path, 'r') as hf:
            self.len = hf['dataset'].shape[0]
        # Lazily-opened read handle, created on first __getitem__ call.
        # Keeping it unopened here means the Dataset pickles cleanly when
        # DataLoader spawns worker processes.
        self._hf = None

    def __getitem__(self, idx):
        # Open the file once per process and reuse the handle, instead of the
        # original open/close on every item — per-item reopening dominates
        # fetch time for small rows.
        if self._hf is None:
            self._hf = h5py.File(self.data_path, 'r')
        ds = self._hf['dataset']
        X = ds[idx, 0:self.feat_dim]
        Y = ds[idx, self.feat_dim:2 * self.feat_dim]
        return (X, Y)

    def __len__(self):
        # Length was cached in __init__ so no file access is needed here.
        return self.len
class NED(nn.Module):
    """Symmetric fully connected autoencoder.

    The encoder maps feat_dim -> int_dim -> latent_dim and the decoder
    mirrors it back latent_dim -> int_dim -> feat_dim. Every linear layer
    is followed by batch normalization; ReLU is applied after the two
    intermediate layers only, so both the latent code and the
    reconstruction are unbounded (raw batch-norm outputs).
    """

    def __init__(self, feat_dim=228, int_dim=128, latent_dim=30):
        super().__init__()
        self.feat_dim = feat_dim
        # Encoder: feat_dim -> int_dim -> latent_dim
        self.fc1 = nn.Linear(feat_dim, int_dim)
        self.bn1 = nn.BatchNorm1d(int_dim)
        self.fc2 = nn.Linear(int_dim, latent_dim)
        self.bn2 = nn.BatchNorm1d(latent_dim)
        # Decoder: latent_dim -> int_dim -> feat_dim
        self.fc3 = nn.Linear(latent_dim, int_dim)
        self.bn3 = nn.BatchNorm1d(int_dim)
        self.fc4 = nn.Linear(int_dim, feat_dim)
        self.bn4 = nn.BatchNorm1d(feat_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        """Map a (batch, feat_dim) input to its (batch, latent_dim) code."""
        hidden = self.relu(self.bn1(self.fc1(x.float())))
        return self.bn2(self.fc2(hidden))

    def decode(self, z):
        """Reconstruct a (batch, feat_dim) output from a latent code."""
        hidden = self.relu(self.bn3(self.fc3(z)))
        return self.bn4(self.fc4(hidden))

    def forward(self, x):
        """Flatten the input to (batch, feat_dim), then encode and decode."""
        flat = x.view(-1, self.feat_dim)
        return self.decode(self.encode(flat))

    def embedding(self, x):
        """Return only the latent code for the (flattened) input."""
        return self.encode(x.view(-1, self.feat_dim))