# gpt2.py (forked from vageesh1/GPT-2-From-Scratch)
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm, ModuleList
from transformers import GPT2Tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _get_clones(module, n):
    '''Return a ModuleList of n independent (deep-copied) clones of the given module.'''
    return ModuleList([copy.deepcopy(module) for _ in range(n)])
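# Illustrative check (an editorial sketch, not part of the original file): because the
# clones are deep copies, each stacked block gets its own parameters instead of sharing
# one set of weights.
def _demo_get_clones():
    layer = nn.Linear(4, 4)
    clones = _get_clones(layer, 3)
    assert clones[0].weight.data_ptr() != clones[1].weight.data_ptr()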
#conv1d layer
class Conv1D(nn.Module):
    def __init__(self, nx, nf):
        '''
        nx: number of input features.
        nf: number of output features (filters).
        '''
        super().__init__()
        self.nf = nf
        # initialise an empty weight matrix of shape (nx, nf)
        w = torch.empty(nx, nf)
        # initialise the weights from a normal distribution with std 0.02
        nn.init.normal_(w, std=0.02)
        # register the weights and biases as learnable parameters
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))
    def forward(self, x):
        '''x: the input tensor of shape [..., nx]'''
        # the output keeps the leading dimensions of x and replaces the last one with nf
        size_out = x.size()[:-1] + (self.nf,)
        # affine projection over the last dimension: flatten the leading dims, then x @ W + b
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x
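# Illustrative shape check (an editorial addition, not part of the original file):
# Conv1D is an affine projection over the last dimension, so a [batch, seq, nx]
# input comes out as [batch, seq, nf]. This is how the QKV projection in Attention
# turns d_model features into 3 * d_model features.
def _demo_conv1d():
    proj = Conv1D(nx=768, nf=768 * 3)
    x = torch.rand(2, 4, 768)
    assert proj(x).shape == (2, 4, 768 * 3)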
#feed forward layer
class FeedForward(nn.Module):
def __init__(self, dropout, d_model=768, nx=768*4):
super().__init__()
self.c_fc = Conv1D(d_model, nx)
self.c_proj = Conv1D(nx, d_model)
self.act = F.gelu
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.c_proj(self.act(self.c_fc(x))))
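# Illustrative sketch (an editorial addition, not part of the original file): the
# position-wise feed-forward block expands d_model to 4 * d_model, applies GELU,
# then projects back, so its output shape matches its input and fits a residual add.
def _demo_feedforward():
    ff = FeedForward(dropout=0.1, d_model=768, nx=768 * 4)
    x = torch.rand(2, 4, 768)
    assert ff(x).shape == x.shape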
#attention layer
class Attention(nn.Module):
def __init__(self, d_model=768, n_head=12, n_ctx=1024, d_head=64, bias=True, scale=False):
        '''Constructor.
        Params:
        d_model: model (embedding) dimension fed into the block
        n_head: number of attention heads
        n_ctx: maximum context length, used to build the causal mask buffer
        d_head: dimension of each attention head (d_model // n_head)
        bias: whether or not to include a bias term
        scale: whether to scale the attention scores by sqrt(d_k), the query/key dimension
        '''
super().__init__()
self.n_head = n_head
self.d_model = d_model
self.c_attn = Conv1D(d_model, d_model*3)
self.scale = scale
self.softmax = nn.Softmax(dim=-1)
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.dropout = nn.Dropout(0.1)
self.c_proj = Conv1D(d_model, d_model)
def split_heads(self, x):
"""
        Split the last dimension into the given number of heads.
        Return shape: [`batch`, `head`, `sequence`, `features`]
"""
new_shape = x.size()[:-1] + (self.n_head, x.size(-1)//self.n_head)
x = x.view(*new_shape)
return x.permute(0, 2, 1, 3)
    def _attn(self, q, k, v, attn_mask=None):
        """Dot-product attention, optionally scaled by sqrt(d_head), with a causal
        mask so each position can only attend to itself and earlier positions."""
        scores = torch.matmul(q, k.transpose(-2, -1))  # dot product between q and k transpose
        if self.scale: scores = scores / math.sqrt(v.size(-1))  # scale by sqrt(d_head)
        # apply the causal mask from the registered `bias` buffer: future positions are
        # set to -inf before the softmax so they receive zero attention weight
        nd, ns = scores.size(-2), scores.size(-1)
        causal_mask = self.bias[:, :, ns - nd:ns, :ns]
        scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        if attn_mask is not None: scores = scores + attn_mask  # add any extra attention mask (e.g. padding)
        scores = self.softmax(scores)  # normalise to attention weights
        scores = self.dropout(scores)  # dropout of 0.1 on the attention weights
        outputs = torch.matmul(scores, v)  # weighted sum of the values
        return outputs
def merge_heads(self, x):
# merging the attention heads into one
x = x.permute(0, 2, 1, 3).contiguous()
new_shape = x.size()[:-2] + (x.size(-2)*x.size(-1),)
return x.view(*new_shape)
def forward(self, x):
        '''Forward pass: project to q, k, v, split into heads, apply attention, merge the heads, and project the output.'''
        x = self.c_attn(x)  # new `x` shape: [batch, seq, 3 * d_model], e.g. `[1, 3, 2304]`
q, k, v = x.split(self.d_model, dim=2)
q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)
out = self._attn(q, k, v)
out = self.merge_heads(out)
out = self.c_proj(out)
return out
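# Illustrative sketch (an editorial addition, not part of the original file): the
# attention block preserves the [batch, seq, d_model] shape. Internally, c_attn
# produces [batch, seq, 3 * d_model], which is split into q, k, v and reshaped to
# [batch, n_head, seq, d_head] before the dot-product attention.
def _demo_attention():
    attn = Attention(d_model=768, n_head=12, n_ctx=1024, d_head=64, bias=True, scale=True)
    x = torch.rand(2, 16, 768)
    assert attn(x).shape == x.shape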
#transformer block
class TransformerBlock(nn.Module):
def __init__(self, d_model=768, n_head=12, dropout=0.1):
super(TransformerBlock, self).__init__()
        # pass the constructor arguments through instead of hard-coding them;
        # GPT-2 scales attention scores by sqrt(d_head), hence scale=True
        self.attn = Attention(d_model=d_model, n_head=n_head, d_head=d_model // n_head, n_ctx=1024, bias=True, scale=True)
        self.feedforward = FeedForward(dropout=dropout, d_model=d_model, nx=d_model * 4)
self.ln_1 = LayerNorm(d_model)
self.ln_2 = LayerNorm(d_model)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.feedforward(self.ln_2(x))
return x
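# Illustrative sketch (an editorial addition, not part of the original file): a
# pre-LayerNorm residual block, LN -> attention -> residual add followed by
# LN -> feed-forward -> residual add, so stacking blocks never changes the shape.
def _demo_transformer_block():
    block = TransformerBlock(d_model=768, n_head=12, dropout=0.1)
    x = torch.rand(2, 16, 768)
    assert block(x).shape == x.shape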
class GPT2(nn.Module):
def __init__(self, nlayers=12, n_ctx=1024, d_model=768, vcb_sz=50257):
        '''nlayers: number of transformer blocks to clone
        n_ctx: maximum sequence length, used for the positional embeddings
        d_model: model (embedding) dimension
        vcb_sz: vocabulary size, which can be altered later for training'''
        super(GPT2, self).__init__()
        self.nlayers = nlayers
        block = TransformerBlock(d_model=d_model, n_head=12, dropout=0.1)
        self.h = _get_clones(block, nlayers)
self.wte = nn.Embedding(vcb_sz, d_model)
self.wpe = nn.Embedding(n_ctx, d_model)
self.drop = nn.Dropout(0.1)
self.ln_f = LayerNorm(d_model)
self.out = nn.Linear(d_model, vcb_sz, bias=False)
self.loss_fn = nn.CrossEntropyLoss()
self.init_weights()
    def init_weights(self):
        '''Initialise the weights and tie the output projection to the token embedding.'''
        self.out.weight = self.wte.weight  # weight tying between wte and the LM head
        self.apply(self._init_weights)
    def _init_weights(self, module):
        '''Linear, Embedding and Conv1D weights are drawn from a normal distribution; their biases are zeroed; LayerNorm is reset to weight 1, bias 0.'''
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
def forward(self, src, labels=None, pos_ids=None):
        '''Add token and positional embeddings, apply dropout, run the result through the transformer blocks and the final LayerNorm, then project to vocabulary logits. If labels are given, return the shifted next-token cross-entropy loss; otherwise return the logits.'''
        if pos_ids is None:
            pos_ids = torch.arange(0, src.size(-1)).unsqueeze(0)
        pos_ids = pos_ids.to(src.device)  # ensure pos_ids is on the same device as src
        inp = self.drop(self.wte(src) + self.wpe(pos_ids))
        for block in self.h: inp = block(inp)
        inp = self.ln_f(inp)
        logits = self.out(inp)
        if labels is not None:
            # shift so that position i predicts token i + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            return loss
return logits
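# Minimal usage sketch (an editorial addition, not part of the original file). It
# encodes a short prompt with the pretrained GPT-2 tokenizer (downloaded or read from
# the local cache on first use) and runs an untrained, randomly initialised model,
# so the logits and loss are only meaningful as a shape and plumbing check.
if __name__ == "__main__":
    model = GPT2(nlayers=12, n_ctx=1024, d_model=768, vcb_sz=50257).to(device)
    model.eval()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    src = tokenizer("GPT-2 from scratch", return_tensors="pt")["input_ids"].to(device)
    with torch.no_grad():
        logits = model(src)            # [batch, seq, vocab] next-token logits
        loss = model(src, labels=src)  # next-token cross-entropy loss
    print("logits shape:", tuple(logits.shape), "| loss:", float(loss))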