-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathinfer.py
34 lines (31 loc) · 1.14 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import tensorflow as tf
tf.app.flags.DEFINE_string("model_type", "charcnn", "默认为cnn")
tf.app.flags.DEFINE_string("sentence", "微信可以登录吗", "默认为cnn")
FLAGS = tf.app.flags.FLAGS
model_type = FLAGS.model_type
sentences = FLAGS.sentence
infer = None
if model_type == 'textcnn':
import model.textcnn.Infer as textcnn_infer
infer = textcnn_infer.Infer()
elif model_type == 'charcnn':
import model.char_cnn.Infer as char_cnn_infer
infer = char_cnn_infer.Infer()
elif model_type == 'fasttext':
import model.fast_text.Infer as fasttext_infer
infer = fasttext_infer.Infer()
elif model_type == 'textrnn':
import model.textrnn.Infer as textrnn_infer
infer = textrnn_infer.Infer()
elif model_type == 'birnn_attention':
import model.birnn_attention.Infer as birnn_attention_infer
infer = birnn_attention_infer.Infer()
elif model_type == 'leam':
import model.leam.Infer as leam_infer
infer = leam_infer.Infer()
elif model_type == 'transformer':
import model.transformer.Infer as transformer_infer
infer = transformer_infer.Infer()
else:
print("do not exist this model")
print(infer.infer([sentences]))