models.py
import torch
import torch.nn as nn


class LSTM_Net(nn.Module):
    def __init__(self, embedding, embedding_dim, hidden_dim, output_dim, num_layers, dropout=0.2, fix_embedding=True):
        super(LSTM_Net, self).__init__()
        # Build the embedding layer from the pretrained embedding matrix
        self.embedding = nn.Embedding(embedding.size(0), embedding.size(1))
        self.embedding.weight = nn.Parameter(embedding)
        # Whether to freeze the embedding: if fix_embedding is False,
        # the embedding weights are trained along with the rest of the model
        self.embedding.weight.requires_grad = not fix_embedding
        self.embedding_dim = embedding.size(1)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
            nn.GELU()
        )

    def forward(self, inputs):
        # inputs has dimension (batch, seq_len) of token indices
        inputs = self.embedding(inputs)
        x, _ = self.lstm(inputs, None)
        # x has dimension (batch, seq_len, hidden_size)
        # Take the LSTM output at the last time step
        x = x[:, -1, :]
        x = self.classifier(x)
        return x
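

# A minimal usage sketch, not part of the original file: the vocabulary size,
# embedding dimension, and batch shape below are illustrative assumptions
# (e.g. a pretrained word-vector matrix of 5000 words x 100 dims standing in
# for a real word2vec/GloVe table).
if __name__ == "__main__":
    pretrained = torch.randn(5000, 100)       # hypothetical pretrained embedding matrix
    model = LSTM_Net(pretrained, embedding_dim=100, hidden_dim=150,
                     output_dim=1, num_layers=1, dropout=0.2, fix_embedding=True)
    tokens = torch.randint(0, 5000, (8, 30))  # batch of 8 sequences, 30 token ids each
    out = model(tokens)
    print(out.shape)                          # torch.Size([8, 1])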