-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
41 lines (29 loc) · 1.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from config.reader import ConfigReader
from config.utils import print_config
from network.network import Network
from network.loss.visualizer import plot_loss_per_minibatch
from data.generate_dataset import generate_dataset
def main(config_path='config/config.ini'):
    """Run the whole pipeline: read config, build data, train, test, and plot.

    Wrapped in a function (and invoked only under the ``__main__`` guard
    below) so that importing this module does not kick off dataset
    generation and training as an import-time side effect.

    Args:
        config_path: Path of the configuration file handed to
            ``ConfigReader(filepath=...)``. Defaults to
            ``'config/config.ini'``, resolved against the current working
            directory exactly as the original hard-coded path was.
    """
    # READ CONFIGURATION
    config_reader = ConfigReader(filepath=config_path)
    config = config_reader.get_data()
    print('\nConfiguration summary:\n----------------------')
    print_config(config)

    # DATA
    # generate_dataset must return exactly three (inputs, targets) pairs, in
    # train / validation / test order (enforced by the unpacking below).
    # NOTE(review): save_examples=True presumably writes sample data to
    # disk — confirm in data.generate_dataset.
    (
        (train_data, train_targets),
        (val_data, val_targets),
        (test_data, test_targets),
    ) = generate_dataset(config, save_examples=True)

    # BUILD MODEL
    # NOTE(review): 'wreg' / 'wrt' look like weight-regularization rate and
    # type — confirm against Network's constructor.
    model = Network(config['loss'], config['layers'], config['wreg'], config['wrt'])

    # TRAIN MODEL
    # fit returns a (training, validation) pair of loss histories.
    print('\nTraining model...')
    train_loss_history, val_loss_history = model.fit(
        train_data, train_targets,
        val_data, val_targets,
        batch_size=config['batch_size'],
        epochs=config['epochs'],
        verbose=config['verbose'],
    )

    # VISUALIZE KERNELS
    # Per the original "if any" comment, this is presumably a no-op for
    # models without kernel layers — confirm in Network.visualize_kernels.
    model.visualize_kernels(save_fig=True)

    # TEST MODEL
    # predict returns a pair; only its first element (a loss history whose
    # first entry is reported as the test error) is used here.
    print('Done training. Testing...')
    test_loss_history, _ = model.predict(test_data, test_targets)
    print('Done testing. Testing error:', test_loss_history[0])

    # VISUALIZE LEARNING
    plot_loss_per_minibatch(config['loss'], train_loss_history, val_loss_history, test_loss_history, save_fig=True)


if __name__ == '__main__':
    main()