Skip to content

Commit 33c2d60

Browse files
committed
[tokenizer] Return if exceed max token length (#2957)
1 parent 3e9da7d commit 33c2d60

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java

+12
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public class Encoding {
2727
private long[] specialTokenMask;
2828
private CharSpan[] charTokenSpans;
2929
private Encoding[] overflowing;
30+
private boolean exceedMaxLength;
3031

3132
protected Encoding(
3233
long[] ids,
@@ -36,6 +37,7 @@ protected Encoding(
3637
long[] attentionMask,
3738
long[] specialTokenMask,
3839
CharSpan[] charTokenSpans,
40+
boolean exceedMaxLength,
3941
Encoding[] overflowing) {
4042
this.ids = ids;
4143
this.typeIds = typeIds;
@@ -44,6 +46,7 @@ protected Encoding(
4446
this.attentionMask = attentionMask;
4547
this.specialTokenMask = specialTokenMask;
4648
this.charTokenSpans = charTokenSpans;
49+
this.exceedMaxLength = exceedMaxLength;
4750
this.overflowing = overflowing;
4851
}
4952

@@ -127,6 +130,15 @@ public CharSpan[] getCharTokenSpans() {
127130
return charTokenSpans;
128131
}
129132

133+
/**
134+
* Returns if tokens exceed max length.
135+
*
136+
* @return {@code true} if tokens exceed max length
137+
*/
138+
public boolean exceedMaxLength() {
139+
return exceedMaxLength;
140+
}
141+
130142
/**
131143
* Returns an array of overflowing encodings.
132144
*

extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -530,10 +530,10 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
530530
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
531531
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);
532532

533+
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);
534+
boolean exceedMaxLength = overflowingHandles.length > 0;
533535
Encoding[] overflowing;
534536
if (withOverflowingTokens) {
535-
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);
536-
537537
overflowing = new Encoding[overflowingHandles.length];
538538
for (int i = 0; i < overflowingHandles.length; ++i) {
539539
overflowing[i] = toEncoding(overflowingHandles[i], true);
@@ -551,6 +551,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
551551
attentionMask,
552552
specialTokenMask,
553553
charSpans,
554+
exceedMaxLength,
554555
overflowing);
555556
}
556557

extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,9 @@ public void testTruncationStride() throws IOException {
330330
.build()) {
331331
String text = "Hello there my friend I am happy to see you";
332332
String textPair = "How are you my friend";
333-
Encoding[] overflowing = tokenizer.encode(text, textPair).getOverflowing();
333+
Encoding encoding = tokenizer.encode(text, textPair);
334+
Assert.assertTrue(encoding.exceedMaxLength());
335+
Encoding[] overflowing = encoding.getOverflowing();
334336

335337
int expectedNumberOfOverflowEncodings = 7;
336338
Assert.assertEquals(overflowing.length, expectedNumberOfOverflowEncodings);

0 commit comments

Comments
 (0)