-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
64 lines (53 loc) · 2.43 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy
import scipy.ndimage
from tqdm import tqdm
from neuralnetwork import neuralNetwork
# Number of nodes in the input, hidden, and output layers.
# 784 = 28 * 28 pixels per MNIST image; 10 = one output node per digit class.
input_nodes = 784
hidden_nodes = 100
output_nodes = 10
# Passed to neuralNetwork as its fifth constructor argument; presumably the
# number of hidden layers — confirm against neuralnetwork.py.
hiddenlayers = 3
# Learning rate.
learning_rate = 0.01
# Side length of a square MNIST image; input_nodes must equal IMAGE_SIDE ** 2.
IMAGE_SIDE = 28
# Data augmentation: every training image is additionally fed to the network
# rotated anticlockwise and clockwise by this many degrees.
ROTATION_DEGREES = 10

# Create the neural network instance.
n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate, hiddenlayers)


def _parse_record(record):
    """Convert one MNIST CSV line into ``(inputs, targets)`` arrays.

    The first comma-separated field is the digit label; the remaining fields
    are pixel intensities in 0..255. Inputs are rescaled into 0.01..1.00 so
    that no input is exactly zero. Targets are 0.01 everywhere except 0.99 at
    the label's index.
    """
    all_values = record.split(',')
    # asarray(..., dtype=float) is what the removed numpy.asfarray did.
    inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
    targets = numpy.zeros(output_nodes) + 0.01
    targets[int(all_values[0])] = 0.99
    return inputs, targets


def _rotated(inputs, degrees):
    """Return the flat *inputs* image rotated by *degrees*, again as a flat vector.

    Positive angles rotate anticlockwise, negative clockwise. Pixels rotated in
    from outside the frame are filled with 0.01, the scaled "no ink" value.
    """
    image = inputs.reshape(IMAGE_SIDE, IMAGE_SIDE)
    # order=1: bilinear interpolation; reshape=False keeps the 28x28 frame.
    rotated = scipy.ndimage.rotate(image, degrees, cval=0.01, order=1, reshape=False)
    return rotated.reshape(input_nodes)


# Read the MNIST training CSV into a list of lines. Blank lines (such as a
# trailing empty line) are skipped because they cannot be parsed as a record.
with open("dataSet/mnist_train.csv", 'r') as training_data_file:
    training_data_list = [line for line in training_data_file if line.strip()]

epochs = 1
# The progress bar advances once per CSV record per epoch (not per train() call).
training_count = len(training_data_list) * epochs
with tqdm(total=training_count) as training_bar:
    training_bar.set_description('now training... ')
    for e in range(epochs):
        # Run over every record of the training data.
        for record in training_data_list:
            inputs, targets = _parse_record(record)
            # Train on the original image and on two rotated variants.
            n.train(inputs, targets)
            n.train(_rotated(inputs, ROTATION_DEGREES), targets)
            n.train(_rotated(inputs, -ROTATION_DEGREES), targets)
            training_bar.update(1)
        # NOTE(review): n.loss is maintained inside neuralNetwork (not visible
        # here). Dividing by the record count ignores that each record triggers
        # three train() calls, and it is unclear whether n.loss resets per
        # epoch — confirm the intended averaging.
        print(n.loss / len(training_data_list))
n.save_parameters()