From 8083d4f6e76dba5250ee3ca15d0c325462b761d3 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Sat, 9 Jan 2021 16:58:41 +0900 Subject: [PATCH] checks the range of id in Decode method --- python/src/sentencepiece/__init__.py | 14 +- python/src/sentencepiece/sentencepiece.i | 26 +- .../src/sentencepiece/sentencepiece_wrap.cxx | 250 ++++++------------ src/sentencepiece_processor.cc | 8 +- src/sentencepiece_processor_test.cc | 12 +- 5 files changed, 97 insertions(+), 213 deletions(-) diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index 566f810a..b5fe7c4c 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -1,5 +1,5 @@ # This file was automatically generated by SWIG (http://www.swig.org). -# Version 4.0.1 +# Version 4.0.2 # # Do not make changes to this file unless you know what you are doing--modify # the SWIG interface file instead. @@ -170,12 +170,6 @@ def serialized_model_proto(self): def LoadFromFile(self, arg): return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg) - def DecodeIdsWithCheck(self, ids): - return _sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck(self, ids) - - def DecodeIdsAsSerializedProtoWithCheck(self, ids): - return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(self, ids) - def Init(self, model_file=None, model_proto=None, @@ -310,7 +304,7 @@ def Decode(self, input): if not input: return self.DecodeIds([]) elif type(input) is int: - return self.DecodeIdsWithCheck([input]) + return self.DecodeIds([input]) elif type(input) is str: return self.DecodePieces([input]) @@ -318,7 +312,7 @@ def _decode(input): if not input: return self.DecodeIds([]) if type(input[0]) is int: - return self.DecodeIdsWithCheck(input) + return self.DecodeId(input) return self.DecodePieces(input) if type(input[0]) is list: @@ -508,8 +502,6 @@ def _batched_func(self, arg): SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode -SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck -SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck for m in [ 'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 40938e44..ef37ff9e 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -198,26 +198,6 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { return $self->Load(arg); } - std::string DecodeIdsWithCheck( - const std::vector &ids) const { - for (int id : ids) - if (id < 0 || id >= $self->GetPieceSize()) - throw sentencepiece::util::Status( - sentencepiece::util::StatusCode::kOutOfRange, - "piece id is out of range."); - return $self->DecodeIds(ids); - } - - util::bytes DecodeIdsAsSerializedProtoWithCheck( - const std::vector &ids) const { - for (int id : ids) - if (id < 0 || id >= $self->GetPieceSize()) - throw sentencepiece::util::Status( - sentencepiece::util::StatusCode::kOutOfRange, - "piece id is out of range."); - return $self->DecodeIdsAsSerializedProto(ids); - } - %pythoncode { def Init(self, model_file=None, @@ -353,7 +333,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { if not input: return self.DecodeIds([]) elif type(input) is int: - return self.DecodeIdsWithCheck([input]) + return self.DecodeIds([input]) elif type(input) is str: return self.DecodePieces([input]) @@ -361,7 +341,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { if not input: return self.DecodeIds([]) if type(input[0]) is int: - return self.DecodeIdsWithCheck(input) + return self.DecodeId(input) return self.DecodePieces(input) if type(input[0]) is list: @@ -729,8 +709,6 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init) SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode -SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck -SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck for m in [ 'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index a358b393..7451d604 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- * This file was automatically generated by SWIG (http://www.swig.org). - * Version 4.0.1 + * Version 4.0.2 * * This file is not intended to be easily readable and contains a number of * coding conventions designed to improve portability and efficiency. Do not make @@ -808,15 +808,19 @@ SWIG_UnpackDataName(const char *c, void *ptr, size_t sz, const char *name) { SWIGINTERN char* SWIG_Python_str_AsChar(PyObject *str) { -#if PY_VERSION_HEX >= 0x03000000 +#if PY_VERSION_HEX >= 0x03030000 + return (char *)PyUnicode_AsUTF8(str); +#elif PY_VERSION_HEX >= 0x03000000 char *newstr = 0; str = PyUnicode_AsUTF8String(str); if (str) { char *cstr; Py_ssize_t len; - PyBytes_AsStringAndSize(str, &cstr, &len); - newstr = (char *) malloc(len+1); - memcpy(newstr, cstr, len+1); + if (PyBytes_AsStringAndSize(str, &cstr, &len) != -1) { + newstr = (char *) malloc(len+1); + if (newstr) + memcpy(newstr, cstr, len+1); + } Py_XDECREF(str); } return newstr; @@ -825,10 +829,10 @@ SWIG_Python_str_AsChar(PyObject *str) #endif } -#if PY_VERSION_HEX >= 0x03000000 -# define SWIG_Python_str_DelForPy3(x) free( (void*) (x) ) +#if PY_VERSION_HEX >= 0x03030000 || PY_VERSION_HEX < 0x03000000 +# define SWIG_Python_str_DelForPy3(x) #else -# define SWIG_Python_str_DelForPy3(x) +# define SWIG_Python_str_DelForPy3(x) free( (void*) (x) ) #endif @@ -1243,6 +1247,19 @@ SWIG_Python_UnpackTuple(PyObject *args, const char *name, Py_ssize_t min, Py_ssi } } +SWIGINTERN int +SWIG_Python_CheckNoKeywords(PyObject *kwargs, const char *name) { + int no_kwargs = 1; + if (kwargs) { + assert(PyDict_Check(kwargs)); + if (PyDict_Size(kwargs) > 0) { + PyErr_Format(PyExc_TypeError, "%s() does not take keyword arguments", name); + no_kwargs = 0; + } + } + return no_kwargs; +} + /* A functor is a function object with one single object argument */ #define SWIG_Python_CallFunctor(functor, obj) PyObject_CallFunctionObjArgs(functor, obj, NULL); @@ -1756,6 +1773,12 @@ SwigPyObject_TypeOnce(void) { #if PY_VERSION_HEX >= 0x03040000 0, /* tp_finalize */ #endif +#if PY_VERSION_HEX >= 0x03080000 + 0, /* tp_vectorcall */ +#endif +#if (PY_VERSION_HEX >= 0x03080000) && (PY_VERSION_HEX < 0x03090000) + 0, /* tp_print */ +#endif #ifdef COUNT_ALLOCS 0, /* tp_allocs */ 0, /* tp_frees */ @@ -1917,6 +1940,12 @@ SwigPyPacked_TypeOnce(void) { #if PY_VERSION_HEX >= 0x03040000 0, /* tp_finalize */ #endif +#if PY_VERSION_HEX >= 0x03080000 + 0, /* tp_vectorcall */ +#endif +#if (PY_VERSION_HEX >= 0x03080000) && (PY_VERSION_HEX < 0x03090000) + 0, /* tp_print */ +#endif #ifdef COUNT_ALLOCS 0, /* tp_allocs */ 0, /* tp_frees */ @@ -2243,8 +2272,10 @@ SWIG_Python_NewShadowInstance(SwigPyClientData *data, PyObject *swig_this) } } #else - PyObject *key = SWIG_This(); - PyObject_SetAttr(inst, key, swig_this); + if (PyObject_SetAttr(inst, SWIG_This(), swig_this) == -1) { + Py_DECREF(inst); + inst = 0; + } #endif } } else { @@ -2256,8 +2287,12 @@ SWIG_Python_NewShadowInstance(SwigPyClientData *data, PyObject *swig_this) inst = ((PyTypeObject *)data->newargs)->tp_new((PyTypeObject *)data->newargs, empty_args, empty_kwargs); Py_DECREF(empty_kwargs); if (inst) { - PyObject_SetAttr(inst, SWIG_This(), swig_this); - Py_TYPE(inst)->tp_flags &= ~Py_TPFLAGS_VALID_VERSION_TAG; + if (PyObject_SetAttr(inst, SWIG_This(), swig_this) == -1) { + Py_DECREF(inst); + inst = 0; + } else { + Py_TYPE(inst)->tp_flags &= ~Py_TPFLAGS_VALID_VERSION_TAG; + } } } Py_DECREF(empty_args); @@ -2274,25 +2309,21 @@ SWIG_Python_NewShadowInstance(SwigPyClientData *data, PyObject *swig_this) return inst; } -SWIGRUNTIME void +SWIGRUNTIME int SWIG_Python_SetSwigThis(PyObject *inst, PyObject *swig_this) { - PyObject *dict; #if !defined(SWIG_PYTHON_SLOW_GETSET_THIS) - PyObject **dictptr = _PyObject_GetDictPtr(inst); - if (dictptr != NULL) { - dict = *dictptr; - if (dict == NULL) { - dict = PyDict_New(); - *dictptr = dict; - } - PyDict_SetItem(dict, SWIG_This(), swig_this); - return; - } + PyObject **dictptr = _PyObject_GetDictPtr(inst); + if (dictptr != NULL) { + PyObject *dict = *dictptr; + if (dict == NULL) { + dict = PyDict_New(); + *dictptr = dict; + } + return PyDict_SetItem(dict, SWIG_This(), swig_this); + } #endif - dict = PyObject_GetAttrString(inst, "__dict__"); - PyDict_SetItem(dict, SWIG_This(), swig_this); - Py_DECREF(dict); + return PyObject_SetAttr(inst, SWIG_This(), swig_this); } @@ -2306,7 +2337,8 @@ SWIG_Python_InitShadowInstance(PyObject *args) { if (sthis) { SwigPyObject_append((PyObject*) sthis, obj[1]); } else { - SWIG_Python_SetSwigThis(obj[0], obj[1]); + if (SWIG_Python_SetSwigThis(obj[0], obj[1]) != 0) + return NULL; } return SWIG_Py_Void(); } @@ -2666,10 +2698,9 @@ SWIGINTERN PyObject *SWIG_PyStaticMethod_New(PyObject *SWIGUNUSEDPARM(self), PyO #define SWIGTYPE_p_sentencepiece__SentencePieceTrainer swig_types[3] #define SWIGTYPE_p_std__string swig_types[4] #define SWIGTYPE_p_std__unordered_mapT_std__string_std__string_t swig_types[5] -#define SWIGTYPE_p_std__vectorT_int_t swig_types[6] -#define SWIGTYPE_p_std__vectorT_std__string_t swig_types[7] -static swig_type_info *swig_types[9]; -static swig_module_info swig_module = {swig_types, 8, 0, 0, 0, 0}; +#define SWIGTYPE_p_std__vectorT_std__string_t swig_types[6] +static swig_type_info *swig_types[8]; +static swig_module_info swig_module = {swig_types, 7, 0, 0, 0, 0}; #define SWIG_TypeQuery(name) SWIG_TypeQueryModule(&swig_module, &swig_module, name) #define SWIG_MangledTypeQuery(name) SWIG_MangledTypeQueryModule(&swig_module, &swig_module, name) @@ -2692,7 +2723,7 @@ static swig_module_info swig_module = {swig_types, 8, 0, 0, 0, 0}; #endif #define SWIG_name "_sentencepiece" -#define SWIGVERSION 0x040001 +#define SWIGVERSION 0x040002 #define SWIG_VERSION SWIGVERSION @@ -2975,9 +3006,11 @@ SWIG_AsCharPtrAndSize(PyObject *obj, char** cptr, size_t* psize, int *alloc) if (alloc) *alloc = SWIG_NEWOBJ; #endif - PyBytes_AsStringAndSize(obj, &cstr, &len); + if (PyBytes_AsStringAndSize(obj, &cstr, &len) == -1) + return SWIG_TypeError; #else - PyString_AsStringAndSize(obj, &cstr, &len); + if (PyString_AsStringAndSize(obj, &cstr, &len) == -1) + return SWIG_TypeError; #endif if (cptr) { if (alloc) { @@ -3285,22 +3318,6 @@ SWIGINTERNINLINE PyObject* SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_LoadFromFile(sentencepiece::SentencePieceProcessor *self,absl::string_view arg){ return self->Load(arg); } -SWIGINTERN std::string sentencepiece_SentencePieceProcessor_DecodeIdsWithCheck(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){ - for (int id : ids) - if (id < 0 || id >= self->GetPieceSize()) - throw sentencepiece::util::Status( - sentencepiece::util::StatusCode::kOutOfRange, - "piece id is out of range."); - return self->DecodeIds(ids); - } -SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){ - for (int id : ids) - if (id < 0 || id >= self->GetPieceSize()) - throw sentencepiece::util::Status( - sentencepiece::util::StatusCode::kOutOfRange, - "piece id is out of range."); - return self->DecodeIdsAsSerializedProto(ids); - } SWIGINTERN int SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val) @@ -4911,125 +4928,6 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor_LoadFromFile(PyObject *SWIGUNU } -SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsWithCheck(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { - PyObject *resultobj = 0; - sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; - std::vector< int > *arg2 = 0 ; - void *argp1 = 0 ; - int res1 = 0 ; - PyObject *swig_obj[2] ; - std::string result; - - if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsWithCheck", 2, 2, swig_obj)) SWIG_fail; - res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); - if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsWithCheck" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); - } - arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); - { - std::vector *out = nullptr; - if (PyList_Check(swig_obj[1])) { - const size_t size = PyList_Size(swig_obj[1]); - out = new std::vector(size); - for (size_t i = 0; i < size; ++i) { - PyObject *o = PyList_GetItem(swig_obj[1], i); - if (PyInt_Check(o)) { - (*out)[i] = static_cast(PyInt_AsLong(o)); - } else { - PyErr_SetString(PyExc_TypeError,"list must contain integers"); - SWIG_fail; - } - } - } else { - PyErr_SetString(PyExc_TypeError,"not a list"); - SWIG_fail; - } - arg2 = out; - } - { - try { - result = sentencepiece_SentencePieceProcessor_DecodeIdsWithCheck((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< int > const &)*arg2); - ReleaseResultObject(resultobj); - } - catch (const sentencepiece::util::Status &status) { - SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); - } - } - { - PyObject *input_type = resultobj; - resultobj = MakePyOutputString(result, input_type); - } - { - delete arg2; - } - return resultobj; -fail: - { - delete arg2; - } - return NULL; -} - - -SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { - PyObject *resultobj = 0; - sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; - std::vector< int > *arg2 = 0 ; - void *argp1 = 0 ; - int res1 = 0 ; - PyObject *swig_obj[2] ; - sentencepiece::util::bytes result; - - if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", 2, 2, swig_obj)) SWIG_fail; - res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); - if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); - } - arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); - { - std::vector *out = nullptr; - if (PyList_Check(swig_obj[1])) { - const size_t size = PyList_Size(swig_obj[1]); - out = new std::vector(size); - for (size_t i = 0; i < size; ++i) { - PyObject *o = PyList_GetItem(swig_obj[1], i); - if (PyInt_Check(o)) { - (*out)[i] = static_cast(PyInt_AsLong(o)); - } else { - PyErr_SetString(PyExc_TypeError,"list must contain integers"); - SWIG_fail; - } - } - } else { - PyErr_SetString(PyExc_TypeError,"not a list"); - SWIG_fail; - } - arg2 = out; - } - { - try { - result = sentencepiece_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< int > const &)*arg2); - ReleaseResultObject(resultobj); - } - catch (const sentencepiece::util::Status &status) { - SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); - } - } - { - resultobj = MakePyOutputBytes(result); - } - { - delete arg2; - } - return resultobj; -fail: - { - delete arg2; - } - return NULL; -} - - SWIGINTERN PyObject *SentencePieceProcessor_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *obj; if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL; @@ -5397,8 +5295,6 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor_pad_id", _wrap_SentencePieceProcessor_pad_id, METH_O, NULL}, { "SentencePieceProcessor_serialized_model_proto", _wrap_SentencePieceProcessor_serialized_model_proto, METH_O, NULL}, { "SentencePieceProcessor_LoadFromFile", _wrap_SentencePieceProcessor_LoadFromFile, METH_VARARGS, NULL}, - { "SentencePieceProcessor_DecodeIdsWithCheck", _wrap_SentencePieceProcessor_DecodeIdsWithCheck, METH_VARARGS, NULL}, - { "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck, METH_VARARGS, NULL}, { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, { "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL}, { "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL}, @@ -5424,7 +5320,6 @@ static swig_type_info _swigt__p_sentencepiece__SentencePieceProcessor = {"_p_sen static swig_type_info _swigt__p_sentencepiece__SentencePieceTrainer = {"_p_sentencepiece__SentencePieceTrainer", "sentencepiece::SentencePieceTrainer *", 0, 0, (void*)0, 0}; static swig_type_info _swigt__p_std__string = {"_p_std__string", "sentencepiece::util::bytes *|std::string *", 0, 0, (void*)0, 0}; static swig_type_info _swigt__p_std__unordered_mapT_std__string_std__string_t = {"_p_std__unordered_mapT_std__string_std__string_t", "std::unordered_map< std::string,std::string > *", 0, 0, (void*)0, 0}; -static swig_type_info _swigt__p_std__vectorT_int_t = {"_p_std__vectorT_int_t", "std::vector< int > *", 0, 0, (void*)0, 0}; static swig_type_info _swigt__p_std__vectorT_std__string_t = {"_p_std__vectorT_std__string_t", "std::vector< std::string > *", 0, 0, (void*)0, 0}; static swig_type_info *swig_type_initial[] = { @@ -5434,7 +5329,6 @@ static swig_type_info *swig_type_initial[] = { &_swigt__p_sentencepiece__SentencePieceTrainer, &_swigt__p_std__string, &_swigt__p_std__unordered_mapT_std__string_std__string_t, - &_swigt__p_std__vectorT_int_t, &_swigt__p_std__vectorT_std__string_t, }; @@ -5444,7 +5338,6 @@ static swig_cast_info _swigc__p_sentencepiece__SentencePieceProcessor[] = { {&_ static swig_cast_info _swigc__p_sentencepiece__SentencePieceTrainer[] = { {&_swigt__p_sentencepiece__SentencePieceTrainer, 0, 0, 0},{0, 0, 0, 0}}; static swig_cast_info _swigc__p_std__string[] = { {&_swigt__p_std__string, 0, 0, 0},{0, 0, 0, 0}}; static swig_cast_info _swigc__p_std__unordered_mapT_std__string_std__string_t[] = { {&_swigt__p_std__unordered_mapT_std__string_std__string_t, 0, 0, 0},{0, 0, 0, 0}}; -static swig_cast_info _swigc__p_std__vectorT_int_t[] = { {&_swigt__p_std__vectorT_int_t, 0, 0, 0},{0, 0, 0, 0}}; static swig_cast_info _swigc__p_std__vectorT_std__string_t[] = { {&_swigt__p_std__vectorT_std__string_t, 0, 0, 0},{0, 0, 0, 0}}; static swig_cast_info *swig_cast_initial[] = { @@ -5454,7 +5347,6 @@ static swig_cast_info *swig_cast_initial[] = { _swigc__p_sentencepiece__SentencePieceTrainer, _swigc__p_std__string, _swigc__p_std__unordered_mapT_std__string_std__string_t, - _swigc__p_std__vectorT_int_t, _swigc__p_std__vectorT_std__string_t, }; @@ -5860,6 +5752,12 @@ extern "C" { #if PY_VERSION_HEX >= 0x03040000 0, /* tp_finalize */ #endif +#if PY_VERSION_HEX >= 0x03080000 + 0, /* tp_vectorcall */ +#endif +#if (PY_VERSION_HEX >= 0x03080000) && (PY_VERSION_HEX < 0x03090000) + 0, /* tp_print */ +#endif #ifdef COUNT_ALLOCS 0, /* tp_allocs */ 0, /* tp_frees */ diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index b367eb60..e4e9d4a0 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "sentencepiece_processor.h" + #include #include #include @@ -22,7 +24,6 @@ #include "model_interface.h" #include "normalizer.h" #include "sentencepiece.pb.h" -#include "sentencepiece_processor.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" @@ -627,8 +628,13 @@ util::Status SentencePieceProcessor::Decode( util::Status SentencePieceProcessor::Decode(const std::vector &ids, SentencePieceText *spt) const { std::vector pieces; + const int num_pieces = GetPieceSize(); pieces.reserve(ids.size()); for (const int id : ids) { + if (id < 0 || id >= num_pieces) { + return util::Status(util::StatusCode::kOutOfRange, + absl::StrCat("Invalid id: ", id)); + } pieces.emplace_back(IdToPiece(id)); } return Decode(pieces, spt); diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 91379736..e10a47c5 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "sentencepiece_processor.h" + #include #include "builder.h" @@ -20,7 +22,6 @@ #include "normalizer.h" #include "sentencepiece.pb.h" #include "sentencepiece_model.pb.h" -#include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "third_party/absl/container/flat_hash_map.h" @@ -741,6 +742,8 @@ TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) { return kMap[id]; } + int GetPieceSize() const override { return 256; } + bool IsUnknown(int id) const override { return (id == 0); } bool IsControl(int id) const override { return (id == 1 || id == 2); } @@ -1136,6 +1139,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { EXPECT_EQ("cba", output); } + // Out of range + { + std::string output; + const std::vector ids = {3, 4, 127}; + EXPECT_FALSE(sp.Decode(ids, &output).ok()); + } + { EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok());