#!/usr/bin/python
#-*- coding: utf-8 -*-
"""Name-to-Language Classification with a Character-level RNN
This example is modified and improved from the PyTorch official tutorial as follows:
- Link: https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial
"""
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import unicodedata, os
import time, glob, random
import sklearn.metrics
from dnn_iris2 import train
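# Note) 'train' comes from dnn_iris2.py in the same repository; it is assumed to
#       run one training epoch over the given data and return the average loss.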
# Define hyperparameters
EPOCH_MAX = 20
EPOCH_LOG = 1
OPTIMIZER_PARAM = {'lr': 0.01}
SCHEDULER_PARAM = {'step_size': 5, 'gamma': 0.5}
DATA_PATH = './data/names/*.txt'
USE_CUDA = torch.cuda.is_available()
SAVE_MODEL = 'rnn_name2lang.pt' # Set to '' if you don't want to save the model
RANDOM_SEED = 777
LETTER_DICT = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'"
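# Note) LETTER_DICT has 57 characters: 26 lowercase and 26 uppercase letters
#       plus the space and the punctuation marks .,;'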

# Convert Unicode to ASCII
# e.g. Ślusàrski to Slusarski
def unicode2ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn' and c in LETTER_DICT)

# Read raw files, each of which contains names belonging to one language
# Note) Each filename is used as the name of its target class.
def load_name_dataset(files):
    data = []
    targets = []
    target_names = []
    for idx, filename in enumerate(files):
        lang = os.path.splitext(os.path.basename(filename))[0]
        names = open(filename, encoding='utf-8').read().strip().split('\n')
        data += [unicode2ascii(name) for name in names]
        targets += [idx] * len(names)
        target_names.append(lang)
    return data, targets, target_names
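# Note) 'data' is a list of ASCII names, 'targets' holds each name's class index,
#       and 'target_names' maps each index to its language (from the filename).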

# Transform the given text to its one-hot encoded tensor
# Note) Tensor size: len(text) x 1 x len(LETTER_DICT)
#                    (sequence_length x batch_size x input_size)
def text2onehot(text, device='cpu'):
    tensor = torch.zeros(len(text), 1, len(LETTER_DICT), device=device)
    for idx, letter in enumerate(text):
        tensor[idx][0][LETTER_DICT.find(letter)] = 1
    return tensor
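# e.g. text2onehot('ab').shape == torch.Size([2, 1, 57])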

# A simple RNN model
# - Try a different RNN unit such as LSTM or GRU
# - Try fewer or more hidden units
# - Try more layers (e.g. 'num_layers=2') and dropout (e.g. 'dropout=0.4')
class MyRNN(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyRNN, self).__init__()
        self.rnn = torch.nn.RNN(input_size, 128)
        self.fc = torch.nn.Linear(128, output_size)

    def forward(self, x):
        output, hidden = self.rnn(x)
        x = self.fc(output[-1]) # Use the output at the last sequence step
        return x
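# Note) Shapes with nn.RNN's default batch_first=False:
#       x: (sequence_length, 1, input_size) -> logits: (1, output_size)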

# Predict the best result for the given text
def predict(text, model):
    model.eval()
    with torch.no_grad():
        dev = next(model.parameters()).device
        text_tensor = text2onehot(text, dev) # Convert text to one-hot vectors
        output = model(text_tensor)[0]       # Remove the batch dimension
        lang = torch.argmax(output)          # Get the best among 18 classes
        return lang.item()

# Predict and report top-k results for the given text
def report_predict(text, model, target_names, n_predict=5):
    print(f'* Name: {text}')
    model.eval()
    with torch.no_grad():
        dev = next(model.parameters()).device
        text_tensor = text2onehot(text, dev)
        output = model(text_tensor)[0]
        prob = nn.functional.softmax(output, dim=0) # Convert the output to probabilities
        top_val, top_idx = prob.topk(n_predict)     # Get top-k among 18 classes
        for i in range(len(top_val)):
            print(f' {i+1}. {target_names[top_idx[i].item()]:<10}: {top_val[i]*100:4.1f} %')

if __name__ == '__main__':
    # 0. Preparation
    torch.manual_seed(RANDOM_SEED)
    if USE_CUDA:
        torch.cuda.manual_seed_all(RANDOM_SEED)
    dev = torch.device('cuda' if USE_CUDA else 'cpu')

    # 1. Load the name2lang dataset
    data, targets, target_names = load_name_dataset(glob.glob(DATA_PATH))
    data_train = [(text2onehot(data[i], device=dev),
                   torch.LongTensor([targets[i]]).to(dev)) for i in range(len(data))]
    random.shuffle(data_train)
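    # Note) Each training sample pairs a (len(name), 1, 57) one-hot tensor with a
    #       single-element target tensor, so training proceeds one name at a time.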

    # 2. Instantiate a model, loss function, and optimizer
    model = MyRNN(len(LETTER_DICT), len(target_names)).to(dev)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), **OPTIMIZER_PARAM)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **SCHEDULER_PARAM)

    # 3.1. Train the model
    loss_list = []
    start = time.time()
    for epoch in range(1, EPOCH_MAX + 1):
        train_loss = train(model, data_train, loss_func, optimizer)
        scheduler.step()
        loss_list.append([epoch, train_loss])
        if epoch % EPOCH_LOG == 0:
            elapse = (time.time() - start) / 60
            print(f'{epoch:>6} ({elapse:>6.2f} min), TrainLoss={train_loss:.6f}, lr={scheduler.get_last_lr()}')
    elapse = (time.time() - start) / 60

    # 3.2. Save the trained model if necessary
    if SAVE_MODEL:
        torch.save(model.state_dict(), SAVE_MODEL)
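    # Note) The saved weights can be restored later with
    #       model.load_state_dict(torch.load(SAVE_MODEL)).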

    # 4.1. Visualize the loss curve
    plt.title(f'Training Loss (time: {elapse:.2f} [min] @ CUDA: {USE_CUDA})')
    loss_array = np.array(loss_list)
    plt.plot(loss_array[:, 0], loss_array[:, 1], label='Training loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss values')
    plt.xlim(loss_array[0, 0], loss_array[-1, 0])
    plt.grid()
    plt.legend()
    plt.show()

    # 4.2. Visualize the confusion matrix
    predicts = [predict(datum, model) for datum in data]
    conf_mat = sklearn.metrics.confusion_matrix(targets, predicts, normalize='true')
    plt.imshow(conf_mat)
    plt.xlabel('Predicted label')
    plt.ylabel('True label')
    plt.gca().set_xticklabels([''] + target_names, rotation=90)
    plt.gca().set_yticklabels([''] + target_names)
    plt.gca().xaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()

    # 5. Test your texts
    report_predict('Choi', model, target_names)
    report_predict('Jane', model, target_names)
    report_predict('Daniel', model, target_names)
    report_predict('Chow', model, target_names)
    report_predict('Tanaka', model, target_names)