forked from sander-ali/Visual_Transformer_code
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvtransformer_code.py
81 lines (62 loc) · 2.73 KB
/
vtransformer_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# This script requires the following Python packages to be installed:
# torch (PyTorch), torchvision, vit_pytorch.
import torch
import torchvision
from vit_pytorch import ViT
import time
import torch.nn.functional as F
import torch.optim as optim
# Seed PyTorch's global random number generator before anything else runs.
torch.manual_seed(97)

# Dataset root directory and the batch sizes used by the two loaders below.
Dpath = '/data/mnist'
Bs_Train = 100
Bs_Test = 1000

# Convert each image to a tensor, then normalise with mean 0.1307 / std 0.3081.
tform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split (downloaded to Dpath if missing) and its shuffled loader.
tr_set = torchvision.datasets.MNIST(
    root=Dpath,
    train=True,
    download=True,
    transform=tform_mnist,
)
tr_load = torch.utils.data.DataLoader(
    tr_set,
    batch_size=Bs_Train,
    shuffle=True,
)

# Test split and its loader, built the same way with the larger batch size.
ts_set = torchvision.datasets.MNIST(
    root=Dpath,
    train=False,
    download=True,
    transform=tform_mnist,
)
ts_load = torch.utils.data.DataLoader(
    ts_set,
    batch_size=Bs_Test,
    shuffle=True,
)
def train_iter(model, optimz, data_load, loss_val):
    """Run one optimisation pass over every batch in ``data_load``.

    For each batch the gradients are cleared, the negative log-likelihood of
    ``log_softmax(model(data))`` is back-propagated and the optimizer steps.
    On every 100th batch (including the first) a progress line is printed and
    that batch's loss is appended to ``loss_val``.

    Args:
        model: Module invoked as ``model(data)``; switched to train mode here.
        optimz: Optimizer whose ``zero_grad()`` / ``step()`` are called per batch.
        data_load: Sized iterable of ``(data, target)`` batches that exposes
            a ``dataset`` attribute.
        loss_val: List that collects the logged loss values (mutated in place).
    """
    total = len(data_load.dataset)
    model.train()

    for step, batch in enumerate(data_load):
        inputs, labels = batch

        optimz.zero_grad()
        log_probs = F.log_softmax(model(inputs), dim=1)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimz.step()

        # Only every 100th batch is reported and recorded.
        if step % 100 != 0:
            continue

        current = batch_loss.item()
        seen = step * len(inputs)
        percent = 100 * step / len(data_load)
        print('[{:5}/{:5} ({:3.0f}%)] Loss: {:6.4f}'.format(
            seen, total, percent, current))
        loss_val.append(current)
def evaluate(model, data_load, loss_val):
    """Evaluate ``model`` on ``data_load`` and report average loss and accuracy.

    The negative log-likelihood is summed over all batches and divided by
    ``len(data_load.dataset)``. That average is appended to ``loss_val`` and
    printed together with the number and percentage of correct predictions.

    Args:
        model: Module invoked as ``model(data)``; switched to eval mode here.
        data_load: Iterable of ``(data, target)`` batches that exposes a
            ``dataset`` attribute.
        loss_val: List that collects the average loss (mutated in place).

    Raises:
        ZeroDivisionError: If ``data_load.dataset`` is empty.
    """
    model.eval()
    samples = len(data_load.dataset)
    csamp = 0
    tloss = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for data, target in data_load:
            output = F.log_softmax(model(data), dim=1)
            # Sum (not mean) per batch so the final division by the dataset
            # size is correct even when the last batch is smaller.
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)
            tloss += loss.item()
            # .item() keeps the running count a plain Python int instead of
            # silently turning it into a 0-dim tensor, so the report below
            # formats ordinary numbers and the percentage is computed in
            # Python floats rather than float32 tensor arithmetic.
            csamp += pred.eq(target).sum().item()

    aloss = tloss / samples
    loss_val.append(aloss)
    print('\nAverage test loss: {:.4f} Accuracy:{:5}/{:5} ({:4.2f}%)\n'.format(
        aloss, csamp, samples, 100.0 * csamp / samples))
# Number of train-then-evaluate rounds to run.
N_EPOCHS = 25

# Timer start; model construction is included in the reported time.
start_time = time.time()

# ViT configured for image_size 28, patch_size 4, one channel and 10 classes.
model = ViT(
    image_size=28,
    patch_size=4,
    num_classes=10,
    channels=1,
    dim=64,
    depth=6,
    heads=8,
    mlp_dim=128,
)
optimz = optim.Adam(model.parameters(), lr=0.003)

# Loss histories filled in by train_iter and evaluate respectively.
trloss_val = []
tsloss_val = []

for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_iter(model, optimz, tr_load, trloss_val)
    evaluate(model, ts_load, tsloss_val)

# Total wall-clock time, printed once after the final epoch.
print('Execution time: {:5.2f} seconds'.format(time.time() - start_time))