Skip to content

Commit

Permalink
checks the range of id in Decode method
Browse files: browse the repository at this point in the history
  • Loading branch information
taku910 committed Jan 9, 2021
1 parent 0e03b57 commit 8083d4f
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 213 deletions.
14 changes: 3 additions & 11 deletions python/src/sentencepiece/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was automatically generated by SWIG (http://www.swig.org).
# Version 4.0.1
# Version 4.0.2
#
# Do not make changes to this file unless you know what you are doing--modify
# the SWIG interface file instead.
Expand Down Expand Up @@ -170,12 +170,6 @@ def serialized_model_proto(self):
def LoadFromFile(self, arg):
return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)

def DecodeIdsWithCheck(self, ids):
return _sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck(self, ids)

def DecodeIdsAsSerializedProtoWithCheck(self, ids):
return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(self, ids)

def Init(self,
model_file=None,
model_proto=None,
Expand Down Expand Up @@ -310,15 +304,15 @@ def Decode(self, input):
if not input:
return self.DecodeIds([])
elif type(input) is int:
return self.DecodeIdsWithCheck([input])
return self.DecodeIds([input])
elif type(input) is str:
return self.DecodePieces([input])

def _decode(input):
if not input:
return self.DecodeIds([])
if type(input[0]) is int:
return self.DecodeIdsWithCheck(input)
return self.DecodeIds(input)
return self.DecodePieces(input)

if type(input[0]) is list:
Expand Down Expand Up @@ -508,8 +502,6 @@ def _batched_func(self, arg):

SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck

for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
Expand Down
26 changes: 2 additions & 24 deletions python/src/sentencepiece/sentencepiece.i
Original file line number Diff line number Diff line change
Expand Up @@ -198,26 +198,6 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
return $self->Load(arg);
}

std::string DecodeIdsWithCheck(
const std::vector<int> &ids) const {
for (int id : ids)
if (id < 0 || id >= $self->GetPieceSize())
throw sentencepiece::util::Status(
sentencepiece::util::StatusCode::kOutOfRange,
"piece id is out of range.");
return $self->DecodeIds(ids);
}

util::bytes DecodeIdsAsSerializedProtoWithCheck(
const std::vector<int> &ids) const {
for (int id : ids)
if (id < 0 || id >= $self->GetPieceSize())
throw sentencepiece::util::Status(
sentencepiece::util::StatusCode::kOutOfRange,
"piece id is out of range.");
return $self->DecodeIdsAsSerializedProto(ids);
}

%pythoncode {
def Init(self,
model_file=None,
Expand Down Expand Up @@ -353,15 +333,15 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
if not input:
return self.DecodeIds([])
elif type(input) is int:
return self.DecodeIdsWithCheck([input])
return self.DecodeIds([input])
elif type(input) is str:
return self.DecodePieces([input])

def _decode(input):
if not input:
return self.DecodeIds([])
if type(input[0]) is int:
return self.DecodeIdsWithCheck(input)
return self.DecodeIds(input)
return self.DecodePieces(input)

if type(input[0]) is list:
Expand Down Expand Up @@ -729,8 +709,6 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)

SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck

for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
Expand Down
Loading

0 comments on commit 8083d4f

Please sign in to comment.