-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdemo_train_utils.py
65 lines (53 loc) · 2.1 KB
/
demo_train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Utilities for training models in the demo"""
import torch
from dae import DAE, Naive_DAE
from rbm import RBM
from utils import *
def train_rbm(train_dl, visible_dim, hidden_dim, k, num_epochs, lr, use_gaussian=False):
    """Create and train an RBM with contrastive divergence (CD-k).

    Uses a custom momentum schedule: 0.5 before epoch 5 and 0.9 afterwards.

    Parameters
    ----------
    train_dl: DataLoader
        training data loader
    visible_dim: int
        number of dimensions in visible (input) layer
    hidden_dim: int
        number of dimensions in hidden layer
    k: int
        number of iterations to run for Gibbs sampling (often 1 is used)
    num_epochs: int
        number of epochs to run for
    lr: float
        learning rate
    use_gaussian: bool
        whether to use a Gaussian distribution for the hidden state

    Returns
    -------
    RBM, Tensor, Tensor
        a trained RBM model, sample input tensor from the final batch,
        reconstructed activation probabilities for that sample input tensor

    Raises
    ------
    ValueError
        if train_dl yields no batches (the return values would otherwise
        be undefined, surfacing as an opaque NameError)
    """
    rbm = RBM(visible_dim=visible_dim, hidden_dim=hidden_dim,
              gaussian_hidden_distribution=use_gaussian)
    mse = torch.nn.MSELoss()  # reconstruction error, used only for monitoring
    v0, pvk = None, None  # track the last batch so we can return it
    for epoch in range(num_epochs):
        train_loss = 0.0
        for data_list in train_dl:
            sample_data = data_list[0].to(DEVICE)
            v0, pvk = sample_data, sample_data
            # Gibbs sampling: alternate hidden/visible sampling k times,
            # starting from the data. Use a throwaway loop variable — the
            # original reused `i` here, shadowing its outer batch index.
            for _ in range(k):
                _, hk = rbm.sample_h(pvk)
                pvk = rbm.sample_v(hk)
            # hidden activation probabilities for the positive (data) and
            # negative (reconstruction) phases of the CD update
            ph0, _ = rbm.sample_h(v0)
            phk, _ = rbm.sample_h(pvk)
            # manual contrastive-divergence weight update
            rbm.update_weights(v0, pvk, ph0, phk, lr,
                               momentum_coef=0.5 if epoch < 5 else 0.9,
                               weight_decay=2e-4,
                               batch_size=sample_data.shape[0])
            # .item() extracts a Python float so we do not accumulate live
            # tensors (and any attached autograd graph) across batches
            train_loss += mse(v0, pvk).item()
        # print mean training loss for this epoch
        print(f"epoch {epoch}: {train_loss/len(train_dl)}")
    if v0 is None:
        raise ValueError("train_dl produced no batches; nothing to train on")
    return rbm, v0, pvk