Skip to content

Commit

Permalink
added functionality to override normalizer spec
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Jan 16, 2024
1 parent 0018af1 commit de1747b
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 2 deletions.
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.00
0.2.0
9 changes: 9 additions & 0 deletions python/src/sentencepiece/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,9 @@ def _CalculateEntropy(self, text, alpha):
# Generated shim: forwards (ins, alpha, num_threads) unchanged to
# _sentencepiece.SentencePieceProcessor__CalculateEntropyBatch and returns
# whatever that native function returns.
def _CalculateEntropyBatch(self, ins, alpha, num_threads):
return _sentencepiece.SentencePieceProcessor__CalculateEntropyBatch(self, ins, alpha, num_threads)

# Generated shim: forwards `args` unchanged to
# _sentencepiece.SentencePieceProcessor__OverrideNormalizerSpec and returns
# its result. The native wrapper rejects anything that is not a dict
# ("not a dictionary") and entries it cannot read as strings
# ("map must contain strings."), so callers should pass a str -> str dict.
def _OverrideNormalizerSpec(self, args):
return _sentencepiece.SentencePieceProcessor__OverrideNormalizerSpec(self, args)

def Init(self,
model_file=None,
model_proto=None,
Expand Down Expand Up @@ -875,6 +878,12 @@ def _normalize(text):
return [_normalize(x) for x in input]
return _normalize(input)

def OverrideNormalizerSpec(self, **kwargs):
  """Overrides normalizer_spec fields using keyword arguments.

  Every keyword value is converted with str() before the resulting
  name -> string dict is handed to self._OverrideNormalizerSpec, e.g.
  OverrideNormalizerSpec(add_dummy_prefix=False) forwards
  {'add_dummy_prefix': 'False'}.

  Returns:
    Whatever self._OverrideNormalizerSpec returns.
  """
  stringified = {name: str(setting) for name, setting in kwargs.items()}
  return self._OverrideNormalizerSpec(stringified)


# Alias: returns the result of self.GetPieceSize().
def piece_size(self):
return self.GetPieceSize()
Expand Down
2 changes: 1 addition & 1 deletion python/src/sentencepiece/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.00'
__version__ = '0.2.0'
20 changes: 20 additions & 0 deletions python/src/sentencepiece/sentencepiece.i
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%ignore sentencepiece::SentencePieceProcessor::NormalizeWithOffsets;

%ignore sentencepiece::SentencePieceProcessor::model_proto;
%ignore sentencepiece::SentencePieceProcessor::mutable_normalizer_spec;
%ignore sentencepiece::SentencePieceProcessor::Load;
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
%ignore sentencepiece::SentencePieceProcessor::SetModel;
Expand Down Expand Up @@ -690,6 +691,19 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
return outs;
}

// Overrides fields of the processor's normalizer_spec.
// Every (field name, value) entry of `args` is applied through
// SentencePieceTrainer::SetProtoField. Processing stops at the first entry
// whose status is not OK and that status is returned; entries applied before
// the failure stay applied. With empty `args` the default-constructed status
// is returned and mutable_normalizer_spec() is never called.
sentencepiece::util::Status _OverrideNormalizerSpec(
    const std::unordered_map<std::string, std::string> &args) {
  sentencepiece::util::Status status;
  for (auto it = args.begin(); it != args.end(); ++it) {
    status = sentencepiece::SentencePieceTrainer::SetProtoField(
        it->first, it->second, $self->mutable_normalizer_spec());
    if (!status.ok()) break;
  }
  return status;
}

%pythoncode {
def Init(self,
model_file=None,
Expand Down Expand Up @@ -1167,6 +1181,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
return [_normalize(x) for x in input]
return _normalize(input)

def OverrideNormalizerSpec(self, **kwargs):
  """Overrides normalizer_spec fields using keyword arguments.

  Every keyword value is converted with str() before the resulting
  name -> string dict is handed to self._OverrideNormalizerSpec, e.g.
  OverrideNormalizerSpec(add_dummy_prefix=False) forwards
  {'add_dummy_prefix': 'False'}.

  Returns:
    Whatever self._OverrideNormalizerSpec returns.
  """
  stringified = {name: str(setting) for name, setting in kwargs.items()}
  return self._OverrideNormalizerSpec(stringified)


# Alias: returns the result of self.GetPieceSize().
def piece_size(self):
return self.GetPieceSize()
Expand Down
77 changes: 77 additions & 0 deletions python/src/sentencepiece/sentencepiece_wrap.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -4033,6 +4033,16 @@ SWIGINTERN std::vector< float > sentencepiece_SentencePieceProcessor__CalculateE
}
return outs;
}
// Applies every (field name, value) entry of `args` to the processor's
// normalizer_spec through SentencePieceTrainer::SetProtoField. Stops at the
// first non-OK status and returns it; otherwise returns the last (OK) status,
// or the default-constructed status when `args` is empty.
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(sentencepiece::SentencePieceProcessor *self,std::unordered_map< std::string,std::string > const &args){
  sentencepiece::util::Status outcome;
  for (const auto &entry : args) {
    outcome = sentencepiece::SentencePieceTrainer::SetProtoField(
        entry.first, entry.second, self->mutable_normalizer_spec());
    if (!outcome.ok()) break;
  }
  return outcome;
}

SWIGINTERN int
SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val)
Expand Down Expand Up @@ -8508,6 +8518,72 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor__CalculateEntropyBatch(PyObjec
}


// Python entry point for SentencePieceProcessor._OverrideNormalizerSpec.
//
// Expects exactly two arguments packed in `args`: the wrapped
// SentencePieceProcessor and a dict. Each dict key and value is read through
// PyInputString and copied into a std::unordered_map<std::string,
// std::string>, which is then passed to
// sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec.
//
// Returns a Python bool built from result.ok() on success. Returns NULL with
// a Python error set when argument unpacking/conversion fails, when the
// second argument is not a dict ("not a dictionary"), when an entry cannot be
// read as a string ("map must contain strings."), or when the resulting
// status is not OK.
//
// The temporary map is owned by `arg2` from the moment it is allocated, so
// both the normal exit and the `fail` path release it.
SWIGINTERN PyObject *_wrap_SentencePieceProcessor__OverrideNormalizerSpec(PyObject *self, PyObject *args) {
  PyObject *resultobj = 0;
  sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
  std::unordered_map< std::string,std::string > *arg2 = 0 ;
  void *argp1 = 0 ;
  int res1 = 0 ;
  PyObject *swig_obj[2] ;
  sentencepiece::util::Status result;

  if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__OverrideNormalizerSpec", 2, 2, swig_obj)) SWIG_fail;
  res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
  if (!SWIG_IsOK(res1)) {
    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__OverrideNormalizerSpec" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
  }
  arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
  {
    std::unordered_map<std::string, std::string> *out = nullptr;
    if (PyDict_Check(swig_obj[1])) {
      PyObject *key, *value;
      Py_ssize_t pos = 0;
      out = new std::unordered_map<std::string, std::string>;
      // Hand ownership to arg2 immediately. Previously arg2 was only set
      // after the loop, so the SWIG_fail below jumped to `fail` with
      // arg2 == 0 and the freshly allocated map was leaked.
      arg2 = out;
      while (PyDict_Next(swig_obj[1], &pos, &key, &value)) {
        const PyInputString key_ustring(key);
        const PyInputString value_ustring(value);
        if (key_ustring.IsAvalable() && value_ustring.IsAvalable()) {
          out->emplace(std::string(key_ustring.data(), key_ustring.size()),
                       std::string(value_ustring.data(), value_ustring.size()));
        } else {
          PyErr_SetString(PyExc_TypeError, "map must contain strings.");
          SWIG_fail;
        }
        resultobj = key_ustring.input_type();
      }
    } else {
      PyErr_SetString(PyExc_TypeError, "not a dictionary");
      SWIG_fail;
    }
  }
  {
    try {
      result = sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(arg1,(std::unordered_map< std::string,std::string > const &)*arg2);
      ReleaseResultObject(resultobj);
    }
    catch (const sentencepiece::util::Status &status) {
      SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
    }
  }
  {
    // A non-OK status is surfaced to Python as an exception, not as False.
    if (!(&result)->ok()) {
      SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
    }
    resultobj = SWIG_From_bool((&result)->ok());
  }
  {
    delete arg2;
  }
  return resultobj;
fail:
  {
    // Deleting a null arg2 is a no-op, so this is safe on every failure path.
    delete arg2;
  }
  return NULL;
}


SWIGINTERN PyObject *SentencePieceProcessor_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *obj;
if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL;
Expand Down Expand Up @@ -9362,6 +9438,7 @@ static PyMethodDef SwigMethods[] = {
{ "SentencePieceProcessor__NormalizeWithOffsets", _wrap_SentencePieceProcessor__NormalizeWithOffsets, METH_VARARGS, NULL},
{ "SentencePieceProcessor__CalculateEntropy", _wrap_SentencePieceProcessor__CalculateEntropy, METH_VARARGS, NULL},
{ "SentencePieceProcessor__CalculateEntropyBatch", _wrap_SentencePieceProcessor__CalculateEntropyBatch, METH_VARARGS, NULL},
{ "SentencePieceProcessor__OverrideNormalizerSpec", _wrap_SentencePieceProcessor__OverrideNormalizerSpec, METH_VARARGS, NULL},
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
{ "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL},
Expand Down
17 changes: 17 additions & 0 deletions python/test/sentencepiece_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,23 @@ def test_normalizer_rule(self):
sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
self.assertEqual('abc', sp.Normalize('ABC'))

def test_override_normalize_spec(self):
  """Checks that overriding normalizer_spec fields changes the encoding."""
  model_path = os.path.join('test', 'test_model.model')
  sp = spm.SentencePieceProcessor(model_file=model_path)
  text = ' hello world '

  # Encoding before any override is applied.
  self.assertEqual(sp.EncodeAsPieces(text), ['▁he', 'll', 'o', '▁world'])

  # Switch the three options off, one field per call.
  for field in (
      'add_dummy_prefix',
      'remove_extra_whitespaces',
      'escape_whitespaces',
  ):
    sp.override_normalizer_spec(**{field: False})

  expected = [' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' ']
  self.assertEqual(sp.EncodeAsPieces(text), expected)


def suite():
suite = unittest.TestSuite()
Expand Down
4 changes: 4 additions & 0 deletions src/sentencepiece_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,10 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
return model_proto_ ? model_proto_->SerializeAsString() : "";
}

// Returns the mutable normalizer_spec of the held model proto, or nullptr
// when no model proto is set.
NormalizerSpec *SentencePieceProcessor::mutable_normalizer_spec() const {
  if (!model_proto_) {
    return nullptr;
  }
  return model_proto_->mutable_normalizer_spec();
}

// Set seed value of random generator.
// Do not set static_cast<unique_int>(-1),
// as this seed is reserved for initializing from
Expand Down
6 changes: 6 additions & 0 deletions src/sentencepiece_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class NBestSentencePieceText;
class ModelInterface;
class SentencePieceText;
class ModelProto;
class NormalizerSpec;

namespace normalizer {
class Normalizer;
Expand Down Expand Up @@ -692,6 +693,11 @@ class SentencePieceProcessor {
// Useful to save the state of this instance via Python's pickle object.
util::bytes serialized_model_proto() const;

// Returns the mutable normalizer_spec of the model proto, or nullptr when no
// model proto is set.
// Updating the internal normalization during encoding/decoding is not
// recommended and may result in unexpected behavior. Use at your own risk.
NormalizerSpec *mutable_normalizer_spec() const;

private:
enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };

Expand Down

0 comments on commit de1747b

Please sign in to comment.