-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_training.py
42 lines (30 loc) · 912 Bytes
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
import os
import numpy as np
sys.path.append('./model')
sys.path.append('./data')
sys.path.append('./embed')
import data
import embed
#from rnn_model import RNNModel
from sep_cnn_model import SepCNNModel
# Upper bound on the feature count passed to ``model.build``; run_training
# computes ``min(len(word_index) + 1, TOP_K)``.
# NOTE(review): presumably meant to match any top-k vocabulary limit applied
# inside embed.sequence_vectorize — verify the two agree.
TOP_K = 20_000
def run_training():
    """Train a ``SepCNNModel`` on the training data and save it to disk.

    Steps:
      1. Load raw texts and labels via ``data.load_train_data()``.
      2. Vectorize the texts with ``embed.sequence_vectorize``, which also
         returns the word index used to build the embedding matrix.
      3. Build a ``SepCNNModel`` requesting a pretrained, non-trainable
         embedding (``use_pretrained_embedding=True``,
         ``is_embedding_trainable=False``) from a 200-dim embedding matrix.
      4. Fit for 10 epochs and save to ``saved_models/<model class name>``.

    Returns:
        None. The trained model is persisted through ``model.save``.
    """
    texts, labels = data.load_train_data()

    # Create the embedding input: vectorized sequences plus the word index.
    # Named ``sequences`` (not ``input``) to avoid shadowing the builtin.
    sequences, word_index = embed.sequence_vectorize(texts)
    labels = np.array(labels)

    # Create the model.
    model = SepCNNModel()

    # Feature count capped at TOP_K. The +1 presumably reserves index 0 for
    # padding — TODO confirm against embed.sequence_vectorize.
    num_features = min(len(word_index) + 1, TOP_K)
    # NOTE(review): the matrix is built from the full, uncapped word_index
    # while num_features is capped at TOP_K — confirm the row counts agree.
    embedding_matrix = embed.get_embedding_matrix(word_index, embedding_dim=200)
    model.build(num_features, sequences.shape,
                use_pretrained_embedding=True,
                is_embedding_trainable=False,
                embedding_matrix=embedding_matrix)

    model.fit(sequences, labels, epochs=10)
    model.save(f'saved_models/{type(model).__name__}')
    # TODO: add a prediction / evaluation step; none is implemented yet.
if __name__ == '__main__':
    # Script entry point: run the training pipeline only when this file is
    # executed directly, never as a side effect of importing it.
    run_training()