Skip to content

Commit

Permalink
Usability improvement bi lstm sort (apache#8944)
Browse files Browse the repository at this point in the history
* Improve usability for the bilstm example

* Remove argparse from infer_sort since it changes existing usage
  • Loading branch information
anirudh2290 authored and zheng-da committed Jun 28, 2018
1 parent d2d3d6a commit c0bfb71
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 45 deletions.
48 changes: 22 additions & 26 deletions example/bi-lstm-sort/README.md
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
This is an example of using bidirection lstm to sort an array.

Firstly, generate data by:

python gen_data.py

Move generated txt files to data directory

mkdir data
mv *.txt data

Then, train the model by:

python lstm_sort.py

At last, test model by:

python infer_sort.py 234 189 785 763 231

and will output sorted seq

189
231
234
763
785


Run the training script by doing the following:

```
python lstm_sort.py --start-range 100 --end-range 1000 --cpu
```
You can provide the start-range and end-range for the numbers and whether to train on the cpu or not.
By default the script tries to train on the GPU. The default start-range is 100 and end-range is 1000.

At last, test model by doing the following:

```
python infer_sort.py 234 189 785 763 231
```

This should output the sorted seq like the following:
```
189
231
234
763
785
```
25 changes: 19 additions & 6 deletions example/bi-lstm-sort/infer_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,29 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import os
import argparse
import numpy as np
import mxnet as mx

from sort_io import BucketSentenceIter, default_build_vocab
from rnn_model import BiLSTMInferenceModel

TRAIN_FILE = "sort.train.txt"
TEST_FILE = "sort.test.txt"
VALID_FILE = "sort.valid.txt"
DATA_DIR = os.path.join(os.getcwd(), "data")
SEQ_LEN = 5

def MakeInput(char, vocab, arr):
idx = vocab[char]
tmp = np.zeros((1,))
tmp[0] = idx
arr[:] = tmp

if __name__ == '__main__':
def main():
tks = sys.argv[1:]
assert len(tks) >= 5, "Please provide 5 numbers for sorting as sequence length is 5"
batch_size = 1
buckets = []
num_hidden = 300
Expand All @@ -42,20 +51,21 @@ def MakeInput(char, vocab, arr):
learning_rate = 0.1
momentum = 0.9

contexts = [mx.context.gpu(i) for i in range(1)]
contexts = [mx.context.cpu(i) for i in range(1)]

vocab = default_build_vocab("./data/sort.train.txt")
vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))
rvocab = {}
for k, v in vocab.items():
rvocab[v] = k

_, arg_params, __ = mx.model.load_checkpoint("sort", 1)
for tk in tks:
assert (tk in vocab), "{} not in range of numbers that the model trained for.".format(tk)

model = BiLSTMInferenceModel(5, len(vocab),
model = BiLSTMInferenceModel(SEQ_LEN, len(vocab),
num_hidden=num_hidden, num_embed=num_embed,
num_label=len(vocab), arg_params=arg_params, ctx=contexts, dropout=0.0)

tks = sys.argv[1:]
data = np.zeros((1, len(tks)))
for k in range(len(tks)):
data[0][k] = vocab[tks[k]]
Expand All @@ -65,3 +75,6 @@ def MakeInput(char, vocab, arr):
for k in range(len(tks)):
print(rvocab[np.argmax(prob, axis = 1)[k]])


if __name__ == '__main__':
sys.exit(main())
1 change: 0 additions & 1 deletion example/bi-lstm-sort/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

# pylint:skip-file
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
from collections import namedtuple
Expand Down
75 changes: 65 additions & 10 deletions example/bi-lstm-sort/lstm_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,75 @@

# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import os
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx
import random
import argparse

from lstm import bi_lstm_unroll
from sort_io import BucketSentenceIter, default_build_vocab

import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)


TRAIN_FILE = "sort.train.txt"
TEST_FILE = "sort.test.txt"
VALID_FILE = "sort.valid.txt"
DATA_DIR = os.path.join(os.getcwd(), "data")
SEQ_LEN = 5

def gen_data(seq_len, start_range, end_range):
if not os.path.exists(DATA_DIR):
try:
logging.info('create directory %s', DATA_DIR)
os.makedirs(DATA_DIR)
except OSError as exc:
if exc.errno != errno.EEXIST:
raise OSError('failed to create ' + DATA_DIR)
vocab = [str(x) for x in range(start_range, end_range)]
sw_train = open(os.path.join(DATA_DIR, TRAIN_FILE), "w")
sw_test = open(os.path.join(DATA_DIR, TEST_FILE), "w")
sw_valid = open(os.path.join(DATA_DIR, VALID_FILE), "w")

for i in range(1000000):
seq = " ".join([vocab[random.randint(0, len(vocab) - 1)] for j in range(seq_len)])
k = i % 50
if k == 0:
sw_test.write(seq + "\n")
elif k == 1:
sw_valid.write(seq + "\n")
else:
sw_train.write(seq + "\n")

sw_train.close()
sw_test.close()

def parse_args():
parser = argparse.ArgumentParser(description="Parse args for lstm_sort example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--start-range', type=int, default=100,
help='starting number of the range')
parser.add_argument('--end-range', type=int, default=1000,
help='Ending number of the range')
parser.add_argument('--cpu', action='store_true',
help='To use CPU for training')
return parser.parse_args()


def Perplexity(label, pred):
label = label.T.reshape((-1,))
loss = 0.
for i in range(pred.shape[0]):
loss += -np.log(max(1e-10, pred[i][int(label[i])]))
return np.exp(loss / label.size)

if __name__ == '__main__':
def main():
args = parse_args()
gen_data(SEQ_LEN, args.start_range, args.end_range)
batch_size = 100
buckets = []
num_hidden = 300
Expand All @@ -43,9 +96,12 @@ def Perplexity(label, pred):
learning_rate = 0.1
momentum = 0.9

contexts = [mx.context.gpu(i) for i in range(1)]
if args.cpu:
contexts = [mx.context.cpu(i) for i in range(1)]
else:
contexts = [mx.context.gpu(i) for i in range(1)]

vocab = default_build_vocab("./data/sort.train.txt")
vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE))

def sym_gen(seq_len):
return bi_lstm_unroll(seq_len, len(vocab),
Expand All @@ -56,9 +112,9 @@ def sym_gen(seq_len):
init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_states = init_c + init_h

data_train = BucketSentenceIter("./data/sort.train.txt", vocab,
data_train = BucketSentenceIter(os.path.join(DATA_DIR, TRAIN_FILE), vocab,
buckets, batch_size, init_states)
data_val = BucketSentenceIter("./data/sort.valid.txt", vocab,
data_val = BucketSentenceIter(os.path.join(DATA_DIR, VALID_FILE), vocab,
buckets, batch_size, init_states)

if len(buckets) == 1:
Expand All @@ -74,12 +130,11 @@ def sym_gen(seq_len):
wd=0.00001,
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

model.fit(X=data_train, eval_data=data_val,
eval_metric = mx.metric.np(Perplexity),
batch_end_callback=mx.callback.Speedometer(batch_size, 50),)

model.save("sort")

if __name__ == '__main__':
sys.exit(main())
1 change: 0 additions & 1 deletion example/bi-lstm-sort/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

Expand Down
1 change: 0 additions & 1 deletion example/bi-lstm-sort/sort_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# pylint: disable=superfluous-parens, no-member, invalid-name
from __future__ import print_function
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

Expand Down

0 comments on commit c0bfb71

Please # to comment.