Skip to content

Commit

Permalink
fixed python module to check the id range.
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Jan 10, 2021
1 parent 3589bfb commit 0e6dfbf
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 131 deletions.
18 changes: 10 additions & 8 deletions python/src/sentencepiece/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,6 @@ def SampleEncodeAsIds(self, input, nbest_size, alpha):
# Forwards `pieces` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_DecodePieces` and returns its result.
def DecodePieces(self, pieces):
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, pieces)

# Forwards `ids` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_DecodeIds` and returns its result.
# No validation of `ids` happens in this Python wrapper.
def DecodeIds(self, ids):
return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids)

# Forwards `input` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_EncodeAsSerializedProto` and returns
# its result.
def EncodeAsSerializedProto(self, input):
return _sentencepiece.SentencePieceProcessor_EncodeAsSerializedProto(self, input)

Expand All @@ -131,9 +128,6 @@ def NBestEncodeAsSerializedProto(self, input, nbest_size):
# Forwards `pieces` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_DecodePiecesAsSerializedProto` and
# returns its result.
def DecodePiecesAsSerializedProto(self, pieces):
return _sentencepiece.SentencePieceProcessor_DecodePiecesAsSerializedProto(self, pieces)

# Forwards `ids` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProto` and
# returns its result. No validation of `ids` happens in this Python wrapper.
def DecodeIdsAsSerializedProto(self, ids):
return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProto(self, ids)

# Returns whatever the native binding
# `_sentencepiece.SentencePieceProcessor_GetPieceSize` reports for this object.
def GetPieceSize(self):
return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)

Expand Down Expand Up @@ -176,6 +170,12 @@ def serialized_model_proto(self):
# Forwards `arg` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_LoadFromFile` and returns its result.
# NOTE(review): `arg` is presumably a model file path -- confirm against the
# SWIG interface.
def LoadFromFile(self, arg):
return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)

# Forwards `ids` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck` and returns its
# result. Any id validation is done on the native side, not in this wrapper.
def DecodeIdsWithCheck(self, ids):
return _sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck(self, ids)

# Forwards `ids` unchanged to the native binding
# `_sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck`
# and returns its result. Any id validation is done on the native side, not in
# this wrapper.
def DecodeIdsAsSerializedProtoWithCheck(self, ids):
return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(self, ids)

def Init(self,
model_file=None,
model_proto=None,
Expand Down Expand Up @@ -310,15 +310,15 @@ def Decode(self, input):
if not input:
return self.DecodeIds([])
elif type(input) is int:
return self.DecodeIds([input])
return self.DecodeIdsWithCheck([input])
elif type(input) is str:
return self.DecodePieces([input])

def _decode(input):
if not input:
return self.DecodeIds([])
if type(input[0]) is int:
return self.DecodeIds(input)
return self.DecodeIdsWithCheck(input)
return self.DecodePieces(input)

if type(input[0]) is list:
Expand Down Expand Up @@ -508,6 +508,8 @@ def _batched_func(self, arg):

# Class-level aliases. Tokenize/Detokenize are alternate names for
# Encode/Decode. DecodeIds and DecodeIdsAsSerializedProto are rebound to the
# *WithCheck methods, so callers using the old names get the WithCheck
# implementations from this point on.
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck

for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
Expand Down
30 changes: 28 additions & 2 deletions python/src/sentencepiece/sentencepiece.i
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
%ignore sentencepiece::SentencePieceProcessor::SampleEncode;
%ignore sentencepiece::SentencePieceProcessor::NBestEncode;
%ignore sentencepiece::SentencePieceProcessor::Decode;
%ignore sentencepiece::SentencePieceProcessor::DecodeIds;
%ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsSerializedProto;
%ignore sentencepiece::SentencePieceProcessor::model_proto;
%ignore sentencepiece::SentencePieceProcessor::Load;
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
Expand All @@ -196,6 +198,28 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
return $self->Load(arg);
}

// Range-checked variant of DecodeIds: returns $self->DecodeIds(ids) only
// after every id has been verified to lie in [0, GetPieceSize()).
// Throws sentencepiece::util::Status with StatusCode::kOutOfRange on the
// first id that is negative or >= GetPieceSize().
// NOTE(review): presumably an %exception handler elsewhere in this interface
// turns the thrown Status into a Python exception -- confirm.
std::string DecodeIdsWithCheck(
const std::vector<int> &ids) const {
// Read the upper bound once, before scanning the ids.
const int num_pieces = $self->GetPieceSize();
for (int id : ids)
if (id < 0 || id >= num_pieces)
throw sentencepiece::util::Status(
sentencepiece::util::StatusCode::kOutOfRange,
"piece id is out of range.");
return $self->DecodeIds(ids);
}

// Range-checked variant of DecodeIdsAsSerializedProto: returns
// $self->DecodeIdsAsSerializedProto(ids) only after every id has been
// verified to lie in [0, GetPieceSize()).
// Throws sentencepiece::util::Status with StatusCode::kOutOfRange on the
// first id that is negative or >= GetPieceSize().
// NOTE(review): presumably an %exception handler elsewhere in this interface
// turns the thrown Status into a Python exception -- confirm.
util::bytes DecodeIdsAsSerializedProtoWithCheck(
const std::vector<int> &ids) const {
// Read the upper bound once, before scanning the ids.
const int num_pieces = $self->GetPieceSize();
for (int id : ids)
if (id < 0 || id >= num_pieces)
throw sentencepiece::util::Status(
sentencepiece::util::StatusCode::kOutOfRange,
"piece id is out of range.");
return $self->DecodeIdsAsSerializedProto(ids);
}

%pythoncode {
def Init(self,
model_file=None,
Expand Down Expand Up @@ -331,15 +355,15 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
if not input:
return self.DecodeIds([])
elif type(input) is int:
return self.DecodeIds([input])
return self.DecodeIdsWithCheck([input])
elif type(input) is str:
return self.DecodePieces([input])

def _decode(input):
if not input:
return self.DecodeIds([])
if type(input[0]) is int:
return self.DecodeIds(input)
return self.DecodeIdsWithCheck(input)
return self.DecodePieces(input)

if type(input[0]) is list:
Expand Down Expand Up @@ -707,6 +731,8 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)

# Class-level aliases. Tokenize/Detokenize are alternate names for
# Encode/Decode. DecodeIds and DecodeIdsAsSerializedProto are rebound to the
# *WithCheck methods, so callers using the old names get the WithCheck
# implementations from this point on.
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck

for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
Expand Down
Loading

0 comments on commit 0e6dfbf

Please sign in to comment.