-
Notifications
You must be signed in to change notification settings - Fork 58
/
Copy pathNMT.java
155 lines (137 loc) · 6.03 KB
/
NMT.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.util.DownloadUtils;
import ai.djl.util.Pair;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import java.util.stream.IntStream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
public class NMT {
    /** Punctuation marks that {@link #preprocessNMT} separates from the preceding character. */
    private static final Set<Character> PUNCTUATION =
            new HashSet<>(Arrays.asList(',', '.', '!', '?'));

    /**
     * Downloads the English-French archive and returns the contents of its {@code fra.txt} entry.
     *
     * <p>The archive is fetched with {@link DownloadUtils#download} into {@code fra-eng.zip} in the
     * working directory. The archive and the entry stream are always closed before returning.
     *
     * @return the text of the first archive entry whose name contains {@code "fra.txt"}, decoded as
     *     UTF-8, or {@code null} if no such entry exists
     * @throws IOException if the download or reading the archive fails
     */
    public static String readDataNMT() throws IOException {
        DownloadUtils.download(
                "http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip", "fra-eng.zip");
        try (ZipFile zipFile = new ZipFile(new File("fra-eng.zip"))) {
            Enumeration<? extends ZipEntry> entries = zipFile.entries();
            while (entries.hasMoreElements()) {
                ZipEntry entry = entries.nextElement();
                if (entry.getName().contains("fra.txt")) {
                    try (InputStream stream = zipFile.getInputStream(entry)) {
                        return new String(stream.readAllBytes(), StandardCharsets.UTF_8);
                    }
                }
            }
        }
        return null;
    }

    /**
     * Preprocesses the English-French dataset.
     *
     * <p>The method does three things:
     *
     * <ul>
     *   <li>replaces narrow no-break spaces (U+202F) and no-break spaces (U+00A0) with ordinary
     *       spaces;
     *   <li>lowercases the text in a locale-independent way;
     *   <li>inserts a space before {@code , . ! ?} whenever the punctuation mark is not already
     *       preceded by a space.
     * </ul>
     *
     * @param text raw corpus text
     * @return the normalized text
     */
    public static String preprocessNMT(String text) {
        // Locale.ROOT: the default-locale toLowerCase() maps 'I' to dotless 'ı' under Turkish/Azeri.
        text = text.replace('\u202f', ' ').replace('\u00a0', ' ').toLowerCase(Locale.ROOT);
        StringBuilder out = new StringBuilder(text.length());
        for (int i = 0; i < text.length(); i++) {
            char currChar = text.charAt(i);
            if (i > 0 && noSpace(currChar, text.charAt(i - 1))) {
                out.append(' ');
            }
            out.append(currChar);
        }
        return out.toString();
    }

    /**
     * Tells whether a space must be inserted between {@code prevChar} and {@code currChar}.
     *
     * @param currChar the current character
     * @param prevChar the character immediately before it
     * @return {@code true} if {@code currChar} is one of {@code , . ! ?} and {@code prevChar} is
     *     not a space
     */
    public static boolean noSpace(Character currChar, Character prevChar) {
        return PUNCTUATION.contains(currChar) && prevChar != ' ';
    }

    /**
     * Splits the corpus into source and target token sequences.
     *
     * <p>Each line is split on tab. Only lines with exactly two tab-separated parts produce an
     * example, and each part is split on single spaces.
     *
     * @param text preprocessed corpus, one example per line
     * @param numExamples maximum number of lines to examine, or {@code null} for no limit
     * @return a pair of (source token arrays, target token arrays), index-aligned
     */
    public static Pair<ArrayList<String[]>, ArrayList<String[]>> tokenizeNMT(
            String text, Integer numExamples) {
        ArrayList<String[]> source = new ArrayList<>();
        ArrayList<String[]> target = new ArrayList<>();
        int i = 0;
        for (String line : text.split("\n")) {
            // '>=' (not '>'): examine exactly numExamples lines, not numExamples + 1.
            if (numExamples != null && i >= numExamples) {
                break;
            }
            String[] parts = line.split("\t");
            if (parts.length == 2) {
                source.add(parts[0].split(" "));
                target.add(parts[1].split(" "));
            }
            i += 1;
        }
        return new Pair<>(source, target);
    }

    /**
     * Truncates or pads a token-index sequence to exactly {@code numSteps} entries.
     *
     * @param integerLine token indices
     * @param numSteps target length; must be non-negative
     * @param paddingToken index used to fill positions past the end of {@code integerLine}
     * @return a new array of length {@code numSteps}
     * @throws IllegalArgumentException if {@code numSteps} is negative
     */
    public static int[] truncatePad(Integer[] integerLine, int numSteps, int paddingToken) {
        if (numSteps < 0) {
            throw new IllegalArgumentException("numSteps must be non-negative: " + numSteps);
        }
        int[] line = new int[numSteps];
        int copyLen = Math.min(integerLine.length, numSteps);
        for (int i = 0; i < copyLen; i++) {
            line[i] = integerLine[i];
        }
        Arrays.fill(line, copyLen, numSteps, paddingToken);
        return line;
    }

    /**
     * Transforms tokenized sentences into a padded index matrix plus per-row valid lengths.
     *
     * <p>Each sentence is converted with {@code vocab.getIdxs}, gets the {@code <eos>} index
     * appended, and is then truncated or padded with {@code <pad>} to {@code numSteps}.
     *
     * @param lines tokenized sentences
     * @param vocab vocabulary used to map tokens to indices
     * @param numSteps fixed sequence length
     * @param manager manager that owns the returned arrays
     * @return a pair of:
     *     <ul>
     *       <li>an INT32 array of shape {@code (lines.size(), numSteps)}, and
     *       <li>the count of non-{@code <pad>} entries in each row.
     *     </ul>
     */
    public static Pair<NDArray, NDArray> buildArrayNMT(
            ArrayList<String[]> lines, Vocab vocab, int numSteps, NDManager manager) {
        int eos = vocab.getIdx("<eos>");
        int pad = vocab.getIdx("<pad>");
        int numRows = lines.size();
        // Fill one host buffer and create a single NDArray. Writing row by row would create one
        // native array per sentence.
        int[] flat = new int[Math.multiplyExact(numRows, numSteps)];
        int row = 0;
        for (String[] tokens : lines) {
            Integer[] idxs = vocab.getIdxs(tokens);
            Integer[] withEos = Arrays.copyOf(idxs, idxs.length + 1);
            withEos[idxs.length] = eos;
            System.arraycopy(truncatePad(withEos, numSteps, pad), 0, flat, row * numSteps, numSteps);
            row += 1;
        }
        NDArray arr = manager.create(flat, new Shape(numRows, numSteps));
        NDArray validLen = arr.neq(pad).sum(new int[] {1});
        return new Pair<>(arr, validLen);
    }

    /**
     * Returns the dataset and the vocabularies of the English-French translation corpus.
     *
     * @param batchSize minibatch size passed to {@code ArrayDataset.Builder#setSampling}
     * @param numSteps fixed sequence length for source and target
     * @param numExamples maximum number of corpus lines to examine
     * @param manager manager that owns the created arrays
     * @return a pair of:
     *     <ul>
     *       <li>the dataset, with data {@code (src, srcValidLen)} and labels
     *           {@code (tgt, tgtValidLen)}, and
     *       <li>the (source vocab, target vocab) pair.
     *     </ul>
     *
     * @throws IOException if the corpus cannot be downloaded/read or {@code fra.txt} is missing
     */
    public static Pair<ArrayDataset, Pair<Vocab, Vocab>> loadDataNMT(
            int batchSize, int numSteps, int numExamples, NDManager manager) throws IOException {
        String rawText = readDataNMT();
        if (rawText == null) {
            throw new IOException("No entry containing \"fra.txt\" found in fra-eng.zip");
        }
        String text = preprocessNMT(rawText);
        Pair<ArrayList<String[]>, ArrayList<String[]>> pair = tokenizeNMT(text, numExamples);
        ArrayList<String[]> source = pair.getKey();
        ArrayList<String[]> target = pair.getValue();
        Vocab srcVocab = buildVocab(source);
        Vocab tgtVocab = buildVocab(target);

        Pair<NDArray, NDArray> srcPair = buildArrayNMT(source, srcVocab, numSteps, manager);
        Pair<NDArray, NDArray> tgtPair = buildArrayNMT(target, tgtVocab, numSteps, manager);
        ArrayDataset dataset =
                new ArrayDataset.Builder()
                        .setData(srcPair.getKey(), srcPair.getValue())
                        .optLabels(tgtPair.getKey(), tgtPair.getValue())
                        // NOTE(review): 'true' presumably requests random (shuffled) sampling;
                        // confirm against DJL's setSampling.
                        .setSampling(batchSize, true)
                        .build();
        return new Pair<>(dataset, new Pair<>(srcVocab, tgtVocab));
    }

    /** Builds a vocabulary over {@code lines} with the reserved tokens {@code <pad> <bos> <eos>}. */
    private static Vocab buildVocab(ArrayList<String[]> lines) {
        // NOTE(review): 2 is presumably Vocab's minimum token frequency; confirm against Vocab.
        return new Vocab(
                lines.toArray(new String[0][]), 2, new String[] {"<pad>", "<bos>", "<eos>"});
    }
}