-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtranslate_kd.sh
61 lines (48 loc) · 1.18 KB
/
translate_kd.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env bash
# Decode a data split with a trained model via generate.py, then post-process
# the hypotheses with rerank.py.
#
# NOTE(review): the "_kd" name plus decoding the `train` split suggests this
# produces distilled targets for knowledge distillation — confirm.
#
# All settings are edited in place below; the only command-line input the
# script reads is the 4th positional argument, used in the generate.py call
# at the bottom of the script.
set -e

# Root directory holding <task>/<tag>/ model directories.
model_root_dir=output
# Translation direction; selects the per-task settings below.
task=zh2en
# Model sub-directory name under $model_root_dir/$task.
model_dir_tag=zh2en_baseline
# Device selection: `gpu` is exported as CUDA_VISIBLE_DEVICES further down;
# when `cpu` is non-empty the script sets use_cpu=--cpu further down.
gpu=0
cpu=
# Data split to decode (passed to generate.py as --gen-subset).
who=train

if [[ "$task" == "zh2en" ]]; then
  data_dir=nist12.filtered.zh-en
  # Number of last epoch checkpoints to average; empty = no averaging.
  ensemble=
  batch_size=64
  beam=6
  length_penalty=1.0
  # NOTE(review): src_lang, tgt_lang and sacrebleu_set are not referenced
  # anywhere else in this script.
  src_lang=zh
  tgt_lang=en
  sacrebleu_set=
else
  # Fail with a non-zero status so callers notice the misconfiguration
  # (a bare `exit` here would return echo's status, i.e. 0).
  echo "unknown task=$task" >&2
  exit 1
fi
# Directory of the trained model; translation outputs are also written here.
model_dir=$model_root_dir/$task/$model_dir_tag
# Checkpoint file name, relative to $model_dir.
checkpoint=checkpoint_best.pt
# When `ensemble` is non-empty, switch to the average of the last $ensemble
# epoch checkpoints. The averaged file is cached in $model_dir and only
# rebuilt when it does not exist yet.
if [[ -n "$ensemble" ]]; then
  if [[ ! -e "$model_dir/last$ensemble.ensemble.pt" ]]; then
    PYTHONPATH="$(pwd)" python3 scripts/average_checkpoints.py \
      --inputs "$model_dir" \
      --output "$model_dir/last$ensemble.ensemble.pt" \
      --num-epoch-checkpoints "$ensemble"
  fi
  checkpoint=last$ensemble.ensemble.pt
fi
# Log of the decoding run: generate.py's stdout, mirrored to the terminal by tee.
output=$model_dir/translation.log

# Optional extra flags for generate.py. An array is used so that an unset
# option contributes no argument at all (rather than an empty string).
extra_args=()
if [[ -n "$cpu" ]]; then
  use_cpu=--cpu
  # Previously use_cpu was computed but never passed to generate.py.
  extra_args+=("$use_cpu")
fi
export CUDA_VISIBLE_DEVICES=$gpu

# Checkpoint to decode with: the 4th positional argument when given (the
# original interface), otherwise the checkpoint selected in $model_dir.
# Before, omitting $4 left `--path` to swallow the next flag.
ckpt_path=${4:-$model_dir/$checkpoint}

# With plain `set -e`, `generate.py | tee` reports only tee's status, so a
# failed decode would fall through to rerank.py on a missing/partial
# hypo.txt. pipefail makes the pipeline fail if generate.py fails.
set -o pipefail
python3 generate.py \
  "../data-bin/$data_dir" \
  --path "$ckpt_path" \
  --gen-subset "$who" \
  --batch-size "$batch_size" \
  --beam "$beam" \
  --lenpen "$length_penalty" \
  --quiet \
  "${extra_args[@]}" \
  --output "$model_dir/hypo.txt" | tee "$output"

# NOTE(review): rerank.py presumably reorders hypo.txt (e.g. back into source
# sentence order) — confirm against rerank.py.
python3 rerank.py "$model_dir/hypo.txt" "$model_dir/hypo.sorted"