From 9841d00e7d95e559b23c56a892c012934d10b3a6 Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Thu, 21 Sep 2017 11:37:03 +0100 Subject: [PATCH 1/8] add python interface --- python/__init__.py | 59 +++++ python/src/module.cpp | 162 +++++++++++++ python/src/range_coder_interface.cpp | 346 +++++++++++++++++++++++++++ python/src/range_coder_interface.h | 52 ++++ python/tests/test_range_coder.py | 229 ++++++++++++++++++ setup.py | 17 ++ 6 files changed, 865 insertions(+) create mode 100644 python/__init__.py create mode 100644 python/src/module.cpp create mode 100644 python/src/range_coder_interface.cpp create mode 100644 python/src/range_coder_interface.h create mode 100644 python/tests/test_range_coder.py create mode 100644 setup.py diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 0000000..0018533 --- /dev/null +++ b/python/__init__.py @@ -0,0 +1,59 @@ +from warnings import warn +from range_coder._range_coder import RangeEncoder, RangeDecoder # noqa: F401 + +try: + import numpy as np +except ImportError: + pass + + +def prob_to_cum_freq(prob, resolution=1024): + """ + Converts probability distribution into a cumulative frequency table. + + Makes sure that non-zero probabilities are represented by non-zero frequencies, + provided that :samp:`len({prob}) <= {resolution}`. + + Parameters + ---------- + prob : ndarray or list + A one-dimensional array representing a probability distribution + + resolution : int + Number of hypothetical samples used to generate integer frequencies + + Returns + ------- + list + Cumulative frequency table + """ + + if len(prob) > resolution: + warn('Resolution smaller than number of symbols.') + + prob = np.asarray(prob, dtype=np.float64) + freq = np.zeros(prob.size, dtype=int) + + # this is similar to gradient descent in KL divergence (convex) + with np.errstate(divide='ignore', invalid='ignore'): + for _ in range(resolution): + freq[np.nanargmax(prob / freq)] += 1 + + return [0] + np.cumsum(freq).tolist() + + +def cum_freq_to_prob(cumFreq): + """ + Converts a cumulative frequency table into a probability distribution. + + Parameters + ---------- + cumFreq : list + Cumulative frequency table + + Returns + ------- + ndarray + Probability distribution + """ + return np.diff(cumFreq).astype(np.float64) / cumFreq[-1] diff --git a/python/src/module.cpp b/python/src/module.cpp new file mode 100644 index 0000000..9790df5 --- /dev/null +++ b/python/src/module.cpp @@ -0,0 +1,162 @@ +#include +#include "range_coder_interface.h" + +static PyMethodDef RangeEncoder_methods[] = { + {"encode", + (PyCFunction)RangeEncoder_encode, + METH_VARARGS | METH_KEYWORDS, + RangeEncoder_encode_doc}, + {"close", + (PyCFunction)RangeEncoder_close, + METH_VARARGS | METH_KEYWORDS, + RangeEncoder_close_doc}, + {0} +}; + + +static PyGetSetDef RangeEncoder_getset[] = { + {0} +}; + + +PyTypeObject RangeEncoder_type = { + PyVarObject_HEAD_INIT(0, 0) + "range_coder.RangeEncoder", /*tp_name*/ + sizeof(RangeEncoderObject), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)RangeEncoder_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT, /*tp_flags*/ + RangeEncoder_doc, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + RangeEncoder_methods, /*tp_methods*/ + 0, /*tp_members*/ + RangeEncoder_getset, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + (initproc)RangeEncoder_init, /*tp_init*/ + 0, /*tp_alloc*/ + RangeEncoder_new, /*tp_new*/ +}; + +static PyMethodDef RangeDecoder_methods[] = { + {"decode", + (PyCFunction)RangeDecoder_decode, + METH_VARARGS | METH_KEYWORDS, + RangeDecoder_decode_doc}, + {"close", + (PyCFunction)RangeDecoder_close, + METH_VARARGS | METH_KEYWORDS, + RangeDecoder_close_doc}, + {0} +}; + + +static PyGetSetDef RangeDecoder_getset[] = { + {0} +}; + + +PyTypeObject RangeDecoder_type = { + PyVarObject_HEAD_INIT(0, 0) + "range_coder.RangeDecoder", /*tp_name*/ + sizeof(RangeDecoderObject), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)RangeDecoder_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequdece*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT, /*tp_flags*/ + RangeDecoder_doc, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + RangeDecoder_methods, /*tp_methods*/ + 0, /*tp_members*/ + RangeDecoder_getset, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + (initproc)RangeDecoder_init, /*tp_init*/ + 0, /*tp_alloc*/ + RangeDecoder_new, /*tp_new*/ +}; + +#if PY_MAJOR_VERSION >= 3 +static PyModuleDef range_coder_module = { + PyModuleDef_HEAD_INIT, + "_range_coder", + "A fast implementation of a range encoder and decoder." + -1, 0, 0, 0, 0, 0 +}; +#endif + + +#if PY_MAJOR_VERSION >= 3 +PyMODINIT_FUNC PyInit__range_coder() { + // create module object + PyObject* module = PyModule_Create(&range_coder_module); +#define RETVAL 0; +#else +PyMODINIT_FUNC init_range_coder() { + PyObject* module = Py_InitModule3( + "_range_coder", 0, "A fast implementation of a range encoder and decoder."); +#define RETVAL void(); +#endif + + if(!module) + return RETVAL; + + // initialize types + if(PyType_Ready(&RangeEncoder_type) < 0) + return RETVAL; + if(PyType_Ready(&RangeDecoder_type) < 0) + return RETVAL; + + // add types to module + Py_INCREF(&RangeEncoder_type); + PyModule_AddObject(module, "RangeEncoder", reinterpret_cast(&RangeEncoder_type)); + Py_INCREF(&RangeDecoder_type); + PyModule_AddObject(module, "RangeDecoder", reinterpret_cast(&RangeDecoder_type)); + +#if PY_MAJOR_VERSION >= 3 + return module; +#endif +} diff --git a/python/src/range_coder_interface.cpp b/python/src/range_coder_interface.cpp new file mode 100644 index 0000000..8e0df33 --- /dev/null +++ b/python/src/range_coder_interface.cpp @@ -0,0 +1,346 @@ +#include "range_coder_interface.h" +#include + +PyObject* RangeEncoder_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + + if (self) + reinterpret_cast(self)->encoder = 0; + + return self; +} + + +PyObject* RangeDecoder_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + + if (self) + reinterpret_cast(self)->decoder = 0; + + return self; +} + + +const char* RangeEncoder_doc = "A fast implementation of a range encoder."; + +int RangeEncoder_init(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"filepath", 0}; + const char* filepath = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) + return -1; + + self->fout = new std::ofstream(filepath, std::ios::out | std::ios::binary); + self->iter = new OutputIterator(*(self->fout)); + self->encoder = new rc_encoder_t(*(self->iter)); + + return 0; +} + + +const char* RangeDecoder_doc = "A fast implementation of a range decoder."; + +int RangeDecoder_init(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"filepath", 0}; + const char* filepath = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) + return -1; + + + self->fin = new std::ifstream(filepath, std::ios::in | std::ios::binary); + self->begin = new InputIterator(*(self->fin)); + self->end = new InputIterator(); + self->decoder = new rc_decoder_t(*(self->begin), *(self->end)); + + return 0; +} + + +void RangeEncoder_dealloc(RangeEncoderObject* self) { + if (self->encoder) { + // flush buffer + self->encoder->final(); + + delete self->encoder; + delete self->iter; + delete self->fout; + } + + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + + +void RangeDecoder_dealloc(RangeDecoderObject* self) { + if (self->decoder) { + delete self->decoder; + delete self->begin; + delete self->end; + delete self->fin; + } + + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + + +const char* RangeEncoder_encode_doc = + "encode(data, cumFreq)\n" + "\n" + "Encodes a list of indices using the given cumulative frequency table.\n" + "\n" + "The length of the frequency table should be the number of possible symbols plus one.\n" + "\n" + "Parameters\n" + "----------\n" + "data : list[int]\n" + " A list of positive integers representing indices into cumulative frequency table\n" + "\n" + "cumFreq : list[int]\n" + " List of increasing positive integers representing cumulative frequencies"; + +PyObject* RangeEncoder_encode(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"data", "cumFreq", 0}; + + PyObject* data; + PyObject* cumFreq; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast(kwlist), &data, &cumFreq)) + return 0; + + if (!self->fout->is_open()) { + PyErr_SetString(PyExc_RuntimeError, "File closed."); + return 0; + } + + if (!PyList_Check(data)) { + PyErr_SetString(PyExc_TypeError, "`data` needs to be a list."); + return 0; + } + + if (!PyList_Check(cumFreq)) { + PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list."); + return 0; + } + + // load cumulative frequency table + Py_ssize_t cumFreqLen = PyList_Size(cumFreq); + + if (cumFreqLen < 2) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); + return 0; + } + + unsigned long* cumFreqArr = new unsigned long[cumFreqLen]; + + for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { + cumFreqArr[i] = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i)); + + if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) + PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and non-decreasing."); + + if (PyErr_Occurred()) { + delete[] cumFreqArr; + return 0; + } + } + + if (cumFreqArr[0] != 0) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); + delete[] cumFreqArr; + return 0; + } + + if(cumFreqArr[cumFreqLen - 1] > std::numeric_limits::max()) { + PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); + return 0; + } + + // load data + Py_ssize_t dataLen = PyList_Size(data); + Py_ssize_t* dataArr = new Py_ssize_t[dataLen]; + + for (Py_ssize_t i = 0; i < dataLen; ++i) { + #if PY_MAJOR_VERSION >= 3 + dataArr[i] = PyLong_AsSsize_t(PyList_GetItem(data, i)); + #else + dataArr[i] = PyInt_AsSsize_t(PyList_GetItem(data, i)); + #endif + + if (!PyErr_Occurred() and dataArr[i] > cumFreqLen - 2) + PyErr_SetString(PyExc_ValueError, "An entry in `data` is too large or `cumFreq` is too short."); + + if (PyErr_Occurred()) + { + delete[] cumFreqArr; + delete[] dataArr; + return 0; + } + } + + bool positive = true; + for (int i = 0; i < cumFreqLen - 1; ++i) { + if (cumFreqArr[i] == cumFreqArr[i + 1]) { + positive = false; + break; + } + } + + // encode data + if (positive) { + // no extra checks necessary + for (int i = 0; i < dataLen; ++i) { + self->encoder->encode( + cumFreqArr[dataArr[i]], + cumFreqArr[dataArr[i] + 1], + cumFreqArr[cumFreqLen - 1]); + } + } else { + for (int i = 0; i < dataLen; ++i) { + unsigned long start = cumFreqArr[dataArr[i]]; + unsigned long end = cumFreqArr[dataArr[i] + 1]; + + if (start == end) { + PyErr_SetString(PyExc_ValueError, "Cannot encode symbol due to zero frequency."); + delete[] cumFreqArr; + delete[] dataArr; + return 0; + } + + self->encoder->encode(start, end, cumFreqArr[cumFreqLen - 1]); + } + } + + delete[] cumFreqArr; + delete[] dataArr; + + Py_INCREF(Py_None); + return Py_None; +} + + +const char* RangeDecoder_decode_doc = + "decode(size, cumFreq)\n" + "\n" + "Decodes a list of indices using the given cumulative frequency table.\n" + "\n" + "The length of the frequency table should be the number of possible symbols plus one.\n" + "\n" + "Parameters\n" + "----------\n" + "size : int\n" + " Number of symbols to decode\n" + "\n" + "cumFreq : list[int]\n" + " List of increasing positive integers representing cumulative frequencies\n" + "\n" + "Returns\n" + "-------\n" + "list[int]\n" + " List of decoded indices"; + +PyObject* RangeDecoder_decode(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"size", "cumFreq", 0}; + + Py_ssize_t size; + PyObject* cumFreq; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "nO", const_cast(kwlist), &size, &cumFreq)) + return 0; + + if (!self->fin->is_open()) { + PyErr_SetString(PyExc_RuntimeError, "File closed."); + return 0; + } + + if (!PyList_Check(cumFreq)) { + PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list of frequencies."); + return 0; + } + + Py_ssize_t cumFreqLen = PyList_Size(cumFreq); + + if (cumFreqLen < 2) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); + return 0; + } + + if (cumFreqLen > MAX_NUM_SYMBOLS + 1) { + PyErr_SetString(PyExc_ValueError, "Number of symbols can be at most MAX_NUM_SYMBOLS."); + return 0; + } + + unsigned long resolution = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, cumFreqLen - 1)); + + if (PyErr_Occurred()) + return 0; + + if(resolution > std::numeric_limits::max()) { + PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); + return 0; + } + + // load cumulative frequency table + SearchType::freq_type cumFreqArr[MAX_NUM_SYMBOLS + 1]; + + for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { + cumFreqArr[i] = static_cast( + PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i))); + + if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) + PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and increasing."); + + if (PyErr_Occurred()) + return 0; + } + + if (cumFreqArr[0] != 0) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); + return 0; + } + + // fill up remaining entries + for (Py_ssize_t i = cumFreqLen; i < MAX_NUM_SYMBOLS + 1; ++i) + cumFreqArr[i] = cumFreqArr[cumFreqLen - 1]; + + // decode data + PyObject* data = PyList_New(size); + for (Py_ssize_t i = 0; i < size; ++i) { + rc_type_t::uint index = self->decoder->decode(cumFreqArr[MAX_NUM_SYMBOLS], cumFreqArr); + + PyList_SetItem(data, i, PyLong_FromUnsignedLong(index)); + + if (PyErr_Occurred()) { + Py_DECREF(data); + return 0; + } + } + + return data; +} + + +const char* RangeEncoder_close_doc = + "close()\n" + "\n" + "Write remaining bits in buffer and close file."; + +PyObject* RangeEncoder_close(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { + self->encoder->final(); + self->fout->close(); + + Py_INCREF(Py_None); + return Py_None; +} + + +const char* RangeDecoder_close_doc = + "close()\n" + "\n" + "Close file."; + +PyObject* RangeDecoder_close(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { + self->fin->close(); + + Py_INCREF(Py_None); + return Py_None; +} diff --git a/python/src/range_coder_interface.h b/python/src/range_coder_interface.h new file mode 100644 index 0000000..85c688a --- /dev/null +++ b/python/src/range_coder_interface.h @@ -0,0 +1,52 @@ +#ifndef RANGE_CODER_INTERFACE_H +#define RANGE_CODER_INTERFACE_H + +#include +#include +#include +#include "range_coder.hpp" + +#define MAX_NUM_SYMBOLS 1024 + +extern const char* RangeEncoder_doc; +extern const char* RangeEncoder_encode_doc; +extern const char* RangeEncoder_close_doc; +extern const char* RangeDecoder_doc; +extern const char* RangeDecoder_decode_doc; +extern const char* RangeDecoder_close_doc; + +typedef std::ostream_iterator OutputIterator; +typedef std::istreambuf_iterator InputIterator; +typedef rc_decoder_search_t SearchType; + +struct RangeEncoderObject { + PyObject_HEAD + rc_encoder_t* encoder; + OutputIterator* iter; + std::ofstream* fout; +}; + +struct RangeDecoderObject { + PyObject_HEAD + rc_decoder_t* decoder; + InputIterator* begin; + InputIterator* end; + std::ifstream* fin; +}; + +PyObject* RangeEncoder_new(PyTypeObject*, PyObject*, PyObject*); +PyObject* RangeDecoder_new(PyTypeObject*, PyObject*, PyObject*); + +int RangeEncoder_init(RangeEncoderObject*, PyObject*, PyObject*); +int RangeDecoder_init(RangeDecoderObject*, PyObject*, PyObject*); + +void RangeEncoder_dealloc(RangeEncoderObject*); +void RangeDecoder_dealloc(RangeDecoderObject*); + +PyObject* RangeEncoder_encode(RangeEncoderObject*, PyObject*, PyObject*); +PyObject* RangeDecoder_decode(RangeDecoderObject*, PyObject*, PyObject*); + +PyObject* RangeEncoder_close(RangeEncoderObject*, PyObject*, PyObject*); +PyObject* RangeDecoder_close(RangeDecoderObject*, PyObject*, PyObject*); + +#endif diff --git a/python/tests/test_range_coder.py b/python/tests/test_range_coder.py new file mode 100644 index 0000000..3ee9d34 --- /dev/null +++ b/python/tests/test_range_coder.py @@ -0,0 +1,229 @@ +import os +import random +import sys +from tempfile import mkstemp + +import numpy as np +import pytest + +from range_coder import RangeEncoder, RangeDecoder +from range_coder import prob_to_cum_freq, cum_freq_to_prob + + +def test_range_coder_overflow(): + """ + Cumulative frequencies must fit in an unsigned integer (assumed to be represented by 32 bits). + This test checks that no error is thrown if the frequencies exceed that limit. + """ + + numBytes = 17 + filepath = mkstemp()[1] + + # encoding one sequence should require 1 byte + prob = [4, 6, 8] + prob = np.asarray(prob, dtype=np.float64) / np.sum(prob) + cumFreq = prob_to_cum_freq(prob, 128) + cumFreq[-1] = 2**32 + + sequence = [2, 2] + data = sequence * numBytes + + encoder = RangeEncoder(filepath) + with pytest.raises(OverflowError): + encoder.encode(data, cumFreq) + encoder.close() + + +def test_range_encoder(): + """ + Tests that RangeEncoder writes the expected bits. + + Tests that writing after closing file throws an exception. + """ + + numBytes = 17 + filepath = mkstemp()[1] + + # encoding one sequence should require 1 byte + cumFreq = [0, 4, 6, 8] + sequence = [0, 0, 0, 0, 1, 2] + sequenceByte = b'\x0b' + data = sequence * numBytes + + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() + + with pytest.raises(RuntimeError): + # file is already closed, should raise an exception + encoder.encode(sequence, cumFreq) + + assert os.stat(filepath).st_size == numBytes + + with open(filepath, 'rb') as handle: + # the first 4 bytes are special + handle.read(4) + + for _ in range(numBytes - 4): + assert handle.read(1) == sequenceByte + + encoder = RangeEncoder(filepath) + with pytest.raises(OverflowError): + # cumFreq contains negative frequencies + encoder.encode(data, [-1, 1]) + with pytest.raises(ValueError): + # cumFreq does not start at zero + encoder.encode(data, [1, 2, 3]) + with pytest.raises(ValueError): + # cumFreq too short + encoder.encode(data, [0, 1]) + with pytest.raises(ValueError): + # symbols with zero probability cannot be encoded + encoder.encode(data, [0, 8, 8, 8]) + with pytest.raises(ValueError): + # invalid frequency table + encoder.encode(data, []) + with pytest.raises(ValueError): + # invalid frequency table + encoder.encode(data, [0]) + encoder.close() + + os.remove(filepath) + + +def test_range_decoder(): + """ + Tests whether RangeDecoder reproduces symbols encoded by RangeEncoder. + """ + + random.seed(558) + + filepath = mkstemp()[1] + + # encoding one sequence should require 1 byte + cumFreq0 = [0, 4, 6, 8] + cumFreq1 = [0, 2, 5, 7, 10, 14] + data0 = [random.randint(0, len(cumFreq0) - 2) for _ in range(10)] + data1 = [random.randint(0, len(cumFreq1) - 2) for _ in range(17)] + + encoder = RangeEncoder(filepath) + encoder.encode(data0, cumFreq0) + encoder.encode(data1, cumFreq1) + encoder.close() + + decoder = RangeDecoder(filepath) + dataRec0 = decoder.decode(len(data0), cumFreq0) + dataRec1 = decoder.decode(len(data1), cumFreq1) + decoder.close() + + # encoded and decoded data should be the same + assert data0 == dataRec0 + assert data1 == dataRec1 + + # make sure reference counting is implemented correctly (call to getrefcount increases it by 1) + assert sys.getrefcount(dataRec0) == 2 + assert sys.getrefcount(dataRec1) == 2 + + decoder = RangeDecoder(filepath) + with pytest.raises(ValueError): + # invalid frequency table + decoder.decode(len(data0), []) + with pytest.raises(ValueError): + # invalid frequency table + decoder.decode(len(data0), [0]) + assert decoder.decode(0, cumFreq0) == [] + + os.remove(filepath) + + +def test_range_decoder_fuzz(): + """ + Test random inputs to the decoder. + """ + + random.seed(827) + randomState = np.random.RandomState(827) + + for _ in range(10): + # generate random frequency table + numSymbols = random.randint(1, 20) + maxFreq = random.randint(2, 100) + cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) + cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int + + # decode random symbols + decoder = RangeDecoder('/dev/urandom') + decoder.decode(100, cumFreq) + + +def test_range_encoder_fuzz(): + """ + Test random inputs to the encoder. + """ + + random.seed(111) + randomState = np.random.RandomState(111) + + filepath = mkstemp()[1] + + for _ in range(10): + # generate random frequency table + numSymbols = random.randint(1, 20) + maxFreq = random.randint(2, 100) + cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) + cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int + + # encode random symbols + dataLen = randomState.randint(0, 10) + data = [random.randint(0, numSymbols - 1) for _ in range(dataLen)] + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() + + os.remove(filepath) + + +def test_prob_to_cum_freq(): + """ + Tests whether prob_to_cum_freq produces a table with the expected number + of entries, number of samples, and that non-zero probabilities are + represented by non-zero increases in frequency. + + Tests that cum_freq_to_prob is normalized and consistent with prob_to_cum_freq. + """ + + randomState = np.random.RandomState(190) + resolution = 1024 + + p0 = randomState.dirichlet([.1] * 50) + cumFreq0 = prob_to_cum_freq(p0, resolution) + p1 = cum_freq_to_prob(cumFreq0) + cumFreq1 = prob_to_cum_freq(p1, resolution) + + # number of hypothetical samples should correspond to resolution + assert cumFreq0[-1] == resolution + assert len(cumFreq0) == len(p0) + 1 + + # non-zero probabilities should have non-zero frequencies + assert np.all(np.diff(cumFreq0)[p0 > 0.] > 0) + + # probabilities should be normalized. + assert np.isclose(np.sum(p1), 1.) + + # while the probabilities might change, frequencies should not + assert cumFreq0 == cumFreq1 + + +def test_prob_to_cum_freq_zero_prob(): + """ + Tests whether prob_to_cum_freq handles zero probabilities as expected. + """ + + prob1 = [0.5, 0.25, 0.25] + cumFreq1 = prob_to_cum_freq(prob1, resolution=8) + + prob0 = [0.5, 0., 0.25, 0.25, 0., 0.] + cumFreq0 = prob_to_cum_freq(prob0, resolution=8) + + # removing entries corresponding to zeros + assert [cumFreq0[0]] + [cumFreq0[i + 1] for i, p in enumerate(prob0) if p > 0.] == cumFreq1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..a6cfe24 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, Extension + +setup( + name="range_coder", + version="1.0", + description="A fast implementation of a range coder", + packages=["range_coder"], + package_dir={"range_coder": "python"}, + license="BSD", + ext_modules=[ + Extension("range_coder._range_coder", + language="c++", + include_dirs=["."], + sources=[ + "python/src/range_coder_interface.cpp", + "python/src/module.cpp" + ])]) From caf365f708aa180eb62899927ab782dc110072c8 Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Thu, 21 Sep 2017 11:37:36 +0100 Subject: [PATCH 2/8] fix indentation --- range_coder.hpp | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/range_coder.hpp b/range_coder.hpp index a617e91..e97612f 100644 --- a/range_coder.hpp +++ b/range_coder.hpp @@ -1,11 +1,11 @@ /* - * Copyrgght (c) 2006, Daisuke Okanohara + * Copyright (c) 2006, Daisuke Okanohara * Copyright (c) 2008-2010, Cybozu Labs, Inc. * All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: - * + * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright notice, @@ -14,7 +14,7 @@ * * Neither the name of the copyright holders nor the names of its * contributors may be used to endorse or promote products derived from this * software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE @@ -63,28 +63,28 @@ template class rc_encoder_t : public rc_type_t { } uint newL = L + r*low; if (newL < L) { - //overflow occured (newL >= 2^32) - //buffer FF FF .. FF -> buffer+1 00 00 .. 00 + // overflow occured (newL >= 2^32) + // buffer FF FF .. FF -> buffer+1 00 00 .. 00 buffer++; for (;carryN > 0; carryN--) { - *iter++ = buffer; - buffer = 0; + *iter++ = buffer; + buffer = 0; } } L = newL; while (R < TOP) { const byte newBuffer = (L >> 24); if (start) { - buffer = newBuffer; - start = false; + buffer = newBuffer; + start = false; } else if (newBuffer == 0xFF) { - carryN++; + carryN++; } else { - *iter++ = buffer; - for (; carryN != 0; carryN--) { - *iter++ = 0xFF; - } - buffer = newBuffer; + *iter++ = buffer; + for (; carryN != 0; carryN--) { + *iter++ = 0xFF; + } + buffer = newBuffer; } L <<= 8; R <<= 8; @@ -101,7 +101,7 @@ template class rc_encoder_t : public rc_type_t { uint t8 = t >> 24, l8 = L >> 24; *iter++ = l8; if (t8 != l8) { - break; + break; } t <<= 8; L <<= 8; @@ -151,7 +151,7 @@ template struct rc_decoder_search_t : public rc_de __m128i b = _mm_cmplt_epi16(v, y); mask = (_mm_movemask_epi8(b) << 16) | _mm_movemask_epi8(a); if (mask) { - return i + (__builtin_ctz(mask) >> 1) - 1; + return i + (__builtin_ctz(mask) >> 1) - 1; } } return 255; @@ -175,25 +175,25 @@ template class rc_decoder_t : public rc_type_ uint decode(const uint total, const freq_type* cumFreq) { const uint r = R / total; const int targetPos = std::min(total-1, D / r); - - //find target s.t. cumFreq[target] <= targetPos < cumFreq[target+1] + + // find target s.t. cumFreq[target] <= targetPos < cumFreq[target+1] const uint target = search_type::get_index(cumFreq, targetPos + search_type::BASE); const uint low = cumFreq[target] - search_type::BASE; const uint high = cumFreq[target+1] - search_type::BASE; - + D -= r * low; if (high != total) { R = r * (high-low); } else { R -= r * low; } - + while (R < TOP) { R <<= 8; D = (D << 8) | next(); } - + return target; } byte next() { From 64bb08d65998f3e6e18c36450c34bb21c383dc64 Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Thu, 3 Oct 2019 21:15:16 +0100 Subject: [PATCH 3/8] add readme with example --- README.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..56da9e7 --- /dev/null +++ b/README.md @@ -0,0 +1,28 @@ +Installation +============ + + pip install range-coder + + +Example +======= + +```python +from range_coder import RangeEncoder, RangeDecoder, prob_to_cum_freq + +data = [2, 0, 1, 0, 0, 0, 1, 2, 2] +prob = [0.5, 0.2, 0.3] + +# convert probabilities to cumulative integer frequency table +cumFreq = prob_to_cum_freq(prob, resolution=128) + +# encode data +encoder = RangeEncoder(filepath) +encoder.encode(data, cumFreq) +encoder.close() + +# decode data +decoder = RangeDecoder(filepath) +dataRec = decoder.decode(len(data), cumFreq) +decoder.close() +``` From 3e66d80961619695d154ba935570f20a37e8e49c Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Sun, 9 Feb 2020 20:01:51 +0100 Subject: [PATCH 4/8] replace spaces with tabs and make it a bit more readable --- python/src/module.cpp | 234 ++++++------ python/src/range_coder_interface.cpp | 517 +++++++++++++-------------- python/src/range_coder_interface.h | 18 +- python/tests/test_range_coder.py | 360 ++++++++++--------- range_coder.hpp | 315 ++++++++-------- setup.py | 29 +- 6 files changed, 757 insertions(+), 716 deletions(-) diff --git a/python/src/module.cpp b/python/src/module.cpp index 9790df5..17f4e21 100644 --- a/python/src/module.cpp +++ b/python/src/module.cpp @@ -2,161 +2,161 @@ #include "range_coder_interface.h" static PyMethodDef RangeEncoder_methods[] = { - {"encode", - (PyCFunction)RangeEncoder_encode, - METH_VARARGS | METH_KEYWORDS, - RangeEncoder_encode_doc}, - {"close", - (PyCFunction)RangeEncoder_close, - METH_VARARGS | METH_KEYWORDS, - RangeEncoder_close_doc}, - {0} + {"encode", + (PyCFunction)RangeEncoder_encode, + METH_VARARGS | METH_KEYWORDS, + RangeEncoder_encode_doc}, + {"close", + (PyCFunction)RangeEncoder_close, + METH_VARARGS | METH_KEYWORDS, + RangeEncoder_close_doc}, + {0} }; static PyGetSetDef RangeEncoder_getset[] = { - {0} + {0} }; PyTypeObject RangeEncoder_type = { - PyVarObject_HEAD_INIT(0, 0) - "range_coder.RangeEncoder", /*tp_name*/ - sizeof(RangeEncoderObject), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - (destructor)RangeEncoder_dealloc, /*tp_dealloc*/ - 0, /*tp_print*/ - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - 0, /*tp_compare*/ - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequence*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash */ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT, /*tp_flags*/ - RangeEncoder_doc, /*tp_doc*/ - 0, /*tp_traverse*/ - 0, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - RangeEncoder_methods, /*tp_methods*/ - 0, /*tp_members*/ - RangeEncoder_getset, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - (initproc)RangeEncoder_init, /*tp_init*/ - 0, /*tp_alloc*/ - RangeEncoder_new, /*tp_new*/ + PyVarObject_HEAD_INIT(0, 0) + "range_coder.RangeEncoder", /*tp_name*/ + sizeof(RangeEncoderObject), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)RangeEncoder_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT, /*tp_flags*/ + RangeEncoder_doc, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + RangeEncoder_methods, /*tp_methods*/ + 0, /*tp_members*/ + RangeEncoder_getset, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + (initproc)RangeEncoder_init, /*tp_init*/ + 0, /*tp_alloc*/ + RangeEncoder_new, /*tp_new*/ }; static PyMethodDef RangeDecoder_methods[] = { - {"decode", - (PyCFunction)RangeDecoder_decode, - METH_VARARGS | METH_KEYWORDS, - RangeDecoder_decode_doc}, - {"close", - (PyCFunction)RangeDecoder_close, - METH_VARARGS | METH_KEYWORDS, - RangeDecoder_close_doc}, - {0} + {"decode", + (PyCFunction)RangeDecoder_decode, + METH_VARARGS | METH_KEYWORDS, + RangeDecoder_decode_doc}, + {"close", + (PyCFunction)RangeDecoder_close, + METH_VARARGS | METH_KEYWORDS, + RangeDecoder_close_doc}, + {0} }; static PyGetSetDef RangeDecoder_getset[] = { - {0} + {0} }; PyTypeObject RangeDecoder_type = { - PyVarObject_HEAD_INIT(0, 0) - "range_coder.RangeDecoder", /*tp_name*/ - sizeof(RangeDecoderObject), /*tp_basicsize*/ - 0, /*tp_itemsize*/ - (destructor)RangeDecoder_dealloc, /*tp_dealloc*/ - 0, /*tp_print*/ - 0, /*tp_getattr*/ - 0, /*tp_setattr*/ - 0, /*tp_compare*/ - 0, /*tp_repr*/ - 0, /*tp_as_number*/ - 0, /*tp_as_sequdece*/ - 0, /*tp_as_mapping*/ - 0, /*tp_hash */ - 0, /*tp_call*/ - 0, /*tp_str*/ - 0, /*tp_getattro*/ - 0, /*tp_setattro*/ - 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT, /*tp_flags*/ - RangeDecoder_doc, /*tp_doc*/ - 0, /*tp_traverse*/ - 0, /*tp_clear*/ - 0, /*tp_richcompare*/ - 0, /*tp_weaklistoffset*/ - 0, /*tp_iter*/ - 0, /*tp_iternext*/ - RangeDecoder_methods, /*tp_methods*/ - 0, /*tp_members*/ - RangeDecoder_getset, /*tp_getset*/ - 0, /*tp_base*/ - 0, /*tp_dict*/ - 0, /*tp_descr_get*/ - 0, /*tp_descr_set*/ - 0, /*tp_dictoffset*/ - (initproc)RangeDecoder_init, /*tp_init*/ - 0, /*tp_alloc*/ - RangeDecoder_new, /*tp_new*/ + PyVarObject_HEAD_INIT(0, 0) + "range_coder.RangeDecoder", /*tp_name*/ + sizeof(RangeDecoderObject), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)RangeDecoder_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequdece*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT, /*tp_flags*/ + RangeDecoder_doc, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + RangeDecoder_methods, /*tp_methods*/ + 0, /*tp_members*/ + RangeDecoder_getset, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + (initproc)RangeDecoder_init, /*tp_init*/ + 0, /*tp_alloc*/ + RangeDecoder_new, /*tp_new*/ }; #if PY_MAJOR_VERSION >= 3 static PyModuleDef range_coder_module = { - PyModuleDef_HEAD_INIT, - "_range_coder", - "A fast implementation of a range encoder and decoder." - -1, 0, 0, 0, 0, 0 + PyModuleDef_HEAD_INIT, + "_range_coder", + "A fast implementation of a range encoder and decoder." + -1, 0, 0, 0, 0, 0 }; #endif #if PY_MAJOR_VERSION >= 3 PyMODINIT_FUNC PyInit__range_coder() { - // create module object - PyObject* module = PyModule_Create(&range_coder_module); + // create module object + PyObject* module = PyModule_Create(&range_coder_module); #define RETVAL 0; #else PyMODINIT_FUNC init_range_coder() { - PyObject* module = Py_InitModule3( - "_range_coder", 0, "A fast implementation of a range encoder and decoder."); + PyObject* module = Py_InitModule3( + "_range_coder", 0, "A fast implementation of a range encoder and decoder."); #define RETVAL void(); #endif - if(!module) - return RETVAL; + if(!module) + return RETVAL; - // initialize types - if(PyType_Ready(&RangeEncoder_type) < 0) - return RETVAL; - if(PyType_Ready(&RangeDecoder_type) < 0) - return RETVAL; + // initialize types + if(PyType_Ready(&RangeEncoder_type) < 0) + return RETVAL; + if(PyType_Ready(&RangeDecoder_type) < 0) + return RETVAL; - // add types to module - Py_INCREF(&RangeEncoder_type); - PyModule_AddObject(module, "RangeEncoder", reinterpret_cast(&RangeEncoder_type)); - Py_INCREF(&RangeDecoder_type); - PyModule_AddObject(module, "RangeDecoder", reinterpret_cast(&RangeDecoder_type)); + // add types to module + Py_INCREF(&RangeEncoder_type); + PyModule_AddObject(module, "RangeEncoder", reinterpret_cast(&RangeEncoder_type)); + Py_INCREF(&RangeDecoder_type); + PyModule_AddObject(module, "RangeDecoder", reinterpret_cast(&RangeDecoder_type)); #if PY_MAJOR_VERSION >= 3 - return module; + return module; #endif } diff --git a/python/src/range_coder_interface.cpp b/python/src/range_coder_interface.cpp index 8e0df33..29ec626 100644 --- a/python/src/range_coder_interface.cpp +++ b/python/src/range_coder_interface.cpp @@ -2,345 +2,344 @@ #include PyObject* RangeEncoder_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { - PyObject* self = type->tp_alloc(type, 0); + PyObject* self = type->tp_alloc(type, 0); - if (self) - reinterpret_cast(self)->encoder = 0; + if (self) + reinterpret_cast(self)->encoder = 0; - return self; + return self; } PyObject* RangeDecoder_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { - PyObject* self = type->tp_alloc(type, 0); + PyObject* self = type->tp_alloc(type, 0); - if (self) - reinterpret_cast(self)->decoder = 0; + if (self) + reinterpret_cast(self)->decoder = 0; - return self; + return self; } const char* RangeEncoder_doc = "A fast implementation of a range encoder."; int RangeEncoder_init(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { - const char* kwlist[] = {"filepath", 0}; - const char* filepath = 0; + const char* kwlist[] = {"filepath", 0}; + const char* filepath = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) - return -1; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) + return -1; - self->fout = new std::ofstream(filepath, std::ios::out | std::ios::binary); - self->iter = new OutputIterator(*(self->fout)); - self->encoder = new rc_encoder_t(*(self->iter)); + self->fout = new std::ofstream(filepath, std::ios::out | std::ios::binary); + self->iter = new OutputIterator(*(self->fout)); + self->encoder = new rc_encoder_t(*(self->iter)); - return 0; + return 0; } const char* RangeDecoder_doc = "A fast implementation of a range decoder."; int RangeDecoder_init(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { - const char* kwlist[] = {"filepath", 0}; - const char* filepath = 0; + const char* kwlist[] = {"filepath", 0}; + const char* filepath = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) - return -1; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) + return -1; + self->fin = new std::ifstream(filepath, std::ios::in | std::ios::binary); + self->begin = new InputIterator(*(self->fin)); + self->end = new InputIterator(); + self->decoder = new rc_decoder_t(*(self->begin), *(self->end)); - self->fin = new std::ifstream(filepath, std::ios::in | std::ios::binary); - self->begin = new InputIterator(*(self->fin)); - self->end = new InputIterator(); - self->decoder = new rc_decoder_t(*(self->begin), *(self->end)); - - return 0; + return 0; } void RangeEncoder_dealloc(RangeEncoderObject* self) { - if (self->encoder) { - // flush buffer - self->encoder->final(); + if (self->encoder) { + // flush buffer + self->encoder->final(); - delete self->encoder; - delete self->iter; - delete self->fout; - } + delete self->encoder; + delete self->iter; + delete self->fout; + } - Py_TYPE(self)->tp_free(reinterpret_cast(self)); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); } void RangeDecoder_dealloc(RangeDecoderObject* self) { - if (self->decoder) { - delete self->decoder; - delete self->begin; - delete self->end; - delete self->fin; - } - - Py_TYPE(self)->tp_free(reinterpret_cast(self)); + if (self->decoder) { + delete self->decoder; + delete self->begin; + delete self->end; + delete self->fin; + } + + Py_TYPE(self)->tp_free(reinterpret_cast(self)); } const char* RangeEncoder_encode_doc = - "encode(data, cumFreq)\n" - "\n" - "Encodes a list of indices using the given cumulative frequency table.\n" - "\n" - "The length of the frequency table should be the number of possible symbols plus one.\n" - "\n" - "Parameters\n" - "----------\n" - "data : list[int]\n" - " A list of positive integers representing indices into cumulative frequency table\n" - "\n" - "cumFreq : list[int]\n" - " List of increasing positive integers representing cumulative frequencies"; + "encode(data, cumFreq)\n" + "\n" + "Encodes a list of indices using the given cumulative frequency table.\n" + "\n" + "The length of the frequency table should be the number of possible symbols plus one.\n" + "\n" + "Parameters\n" + "----------\n" + "data : list[int]\n" + " A list of positive integers representing indices into cumulative frequency table\n" + "\n" + "cumFreq : list[int]\n" + " List of increasing positive integers representing cumulative frequencies"; PyObject* RangeEncoder_encode(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { - const char* kwlist[] = {"data", "cumFreq", 0}; - - PyObject* data; - PyObject* cumFreq; - - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast(kwlist), &data, &cumFreq)) - return 0; - - if (!self->fout->is_open()) { - PyErr_SetString(PyExc_RuntimeError, "File closed."); - return 0; - } - - if (!PyList_Check(data)) { - PyErr_SetString(PyExc_TypeError, "`data` needs to be a list."); - return 0; - } - - if (!PyList_Check(cumFreq)) { - PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list."); - return 0; - } - - // load cumulative frequency table - Py_ssize_t cumFreqLen = PyList_Size(cumFreq); - - if (cumFreqLen < 2) { - PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); - return 0; - } - - unsigned long* cumFreqArr = new unsigned long[cumFreqLen]; - - for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { - cumFreqArr[i] = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i)); - - if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) - PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and non-decreasing."); - - if (PyErr_Occurred()) { - delete[] cumFreqArr; - return 0; - } - } - - if (cumFreqArr[0] != 0) { - PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); - delete[] cumFreqArr; - return 0; - } - - if(cumFreqArr[cumFreqLen - 1] > std::numeric_limits::max()) { - PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); - return 0; - } - - // load data - Py_ssize_t dataLen = PyList_Size(data); - Py_ssize_t* dataArr = new Py_ssize_t[dataLen]; - - for (Py_ssize_t i = 0; i < dataLen; ++i) { - #if PY_MAJOR_VERSION >= 3 - dataArr[i] = PyLong_AsSsize_t(PyList_GetItem(data, i)); - #else - dataArr[i] = PyInt_AsSsize_t(PyList_GetItem(data, i)); - #endif - - if (!PyErr_Occurred() and dataArr[i] > cumFreqLen - 2) - PyErr_SetString(PyExc_ValueError, "An entry in `data` is too large or `cumFreq` is too short."); - - if (PyErr_Occurred()) - { - delete[] cumFreqArr; - delete[] dataArr; - return 0; - } - } - - bool positive = true; - for (int i = 0; i < cumFreqLen - 1; ++i) { - if (cumFreqArr[i] == cumFreqArr[i + 1]) { - positive = false; - break; - } - } - - // encode data - if (positive) { - // no extra checks necessary - for (int i = 0; i < dataLen; ++i) { - self->encoder->encode( - cumFreqArr[dataArr[i]], - cumFreqArr[dataArr[i] + 1], - cumFreqArr[cumFreqLen - 1]); - } - } else { - for (int i = 0; i < dataLen; ++i) { - unsigned long start = cumFreqArr[dataArr[i]]; - unsigned long end = cumFreqArr[dataArr[i] + 1]; - - if (start == end) { - PyErr_SetString(PyExc_ValueError, "Cannot encode symbol due to zero frequency."); - delete[] cumFreqArr; - delete[] dataArr; - return 0; - } - - self->encoder->encode(start, end, cumFreqArr[cumFreqLen - 1]); - } - } - - delete[] cumFreqArr; - delete[] dataArr; - - Py_INCREF(Py_None); - return Py_None; + const char* kwlist[] = {"data", "cumFreq", 0}; + + PyObject* data; + PyObject* cumFreq; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast(kwlist), &data, &cumFreq)) + return 0; + + if (!self->fout->is_open()) { + PyErr_SetString(PyExc_RuntimeError, "File closed."); + return 0; + } + + if (!PyList_Check(data)) { + PyErr_SetString(PyExc_TypeError, "`data` needs to be a list."); + return 0; + } + + if (!PyList_Check(cumFreq)) { + PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list."); + return 0; + } + + // load cumulative frequency table + Py_ssize_t cumFreqLen = PyList_Size(cumFreq); + + if (cumFreqLen < 2) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); + return 0; + } + + unsigned long* cumFreqArr = new unsigned long[cumFreqLen]; + + for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { + cumFreqArr[i] = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i)); + + if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) + PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and non-decreasing."); + + if (PyErr_Occurred()) { + delete[] cumFreqArr; + return 0; + } + } + + if (cumFreqArr[0] != 0) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); + delete[] cumFreqArr; + return 0; + } + + if(cumFreqArr[cumFreqLen - 1] > std::numeric_limits::max()) { + PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); + return 0; + } + + // load data + Py_ssize_t dataLen = PyList_Size(data); + Py_ssize_t* dataArr = new Py_ssize_t[dataLen]; + + for (Py_ssize_t i = 0; i < dataLen; ++i) { + #if PY_MAJOR_VERSION >= 3 + dataArr[i] = PyLong_AsSsize_t(PyList_GetItem(data, i)); + #else + dataArr[i] = PyInt_AsSsize_t(PyList_GetItem(data, i)); + #endif + + if (!PyErr_Occurred() and dataArr[i] > cumFreqLen - 2) + PyErr_SetString(PyExc_ValueError, "An entry in `data` is too large or `cumFreq` is too short."); + + if (PyErr_Occurred()) + { + delete[] cumFreqArr; + delete[] dataArr; + return 0; + } + } + + bool positive = true; + for (int i = 0; i < cumFreqLen - 1; ++i) { + if (cumFreqArr[i] == cumFreqArr[i + 1]) { + positive = false; + break; + } + } + + // encode data + if (positive) { + // no extra checks necessary + for (int i = 0; i < dataLen; ++i) { + self->encoder->encode( + cumFreqArr[dataArr[i]], + cumFreqArr[dataArr[i] + 1], + cumFreqArr[cumFreqLen - 1]); + } + } else { + for (int i = 0; i < dataLen; ++i) { + unsigned long start = cumFreqArr[dataArr[i]]; + unsigned long end = cumFreqArr[dataArr[i] + 1]; + + if (start == end) { + PyErr_SetString(PyExc_ValueError, "Cannot encode symbol due to zero frequency."); + delete[] cumFreqArr; + delete[] dataArr; + return 0; + } + + self->encoder->encode(start, end, cumFreqArr[cumFreqLen - 1]); + } + } + + delete[] cumFreqArr; + delete[] dataArr; + + Py_INCREF(Py_None); + return Py_None; } const char* RangeDecoder_decode_doc = - "decode(size, cumFreq)\n" - "\n" - "Decodes a list of indices using the given cumulative frequency table.\n" - "\n" - "The length of the frequency table should be the number of possible symbols plus one.\n" - "\n" - "Parameters\n" - "----------\n" - "size : int\n" - " Number of symbols to decode\n" - "\n" - "cumFreq : list[int]\n" - " List of increasing positive integers representing cumulative frequencies\n" - "\n" - "Returns\n" - "-------\n" - "list[int]\n" - " List of decoded indices"; + "decode(size, cumFreq)\n" + "\n" + "Decodes a list of indices using the given cumulative frequency table.\n" + "\n" + "The length of the frequency table should be the number of possible symbols plus one.\n" + "\n" + "Parameters\n" + "----------\n" + "size : int\n" + " Number of symbols to decode\n" + "\n" + "cumFreq : list[int]\n" + " List of increasing positive integers representing cumulative frequencies\n" + "\n" + "Returns\n" + "-------\n" + "list[int]\n" + " List of decoded indices"; PyObject* RangeDecoder_decode(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { - const char* kwlist[] = {"size", "cumFreq", 0}; + const char* kwlist[] = {"size", "cumFreq", 0}; - Py_ssize_t size; - PyObject* cumFreq; + Py_ssize_t size; + PyObject* cumFreq; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "nO", const_cast(kwlist), &size, &cumFreq)) - return 0; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "nO", const_cast(kwlist), &size, &cumFreq)) + return 0; - if (!self->fin->is_open()) { - PyErr_SetString(PyExc_RuntimeError, "File closed."); - return 0; - } + if (!self->fin->is_open()) { + PyErr_SetString(PyExc_RuntimeError, "File closed."); + return 0; + } - if (!PyList_Check(cumFreq)) { - PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list of frequencies."); - return 0; - } + if (!PyList_Check(cumFreq)) { + PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list of frequencies."); + return 0; + } - Py_ssize_t cumFreqLen = PyList_Size(cumFreq); + Py_ssize_t cumFreqLen = PyList_Size(cumFreq); - if (cumFreqLen < 2) { - PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); - return 0; - } + if (cumFreqLen < 2) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); + return 0; + } - if (cumFreqLen > MAX_NUM_SYMBOLS + 1) { - PyErr_SetString(PyExc_ValueError, "Number of symbols can be at most MAX_NUM_SYMBOLS."); - return 0; - } + if (cumFreqLen > MAX_NUM_SYMBOLS + 1) { + PyErr_SetString(PyExc_ValueError, "Number of symbols can be at most MAX_NUM_SYMBOLS."); + return 0; + } - unsigned long resolution = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, cumFreqLen - 1)); + unsigned long resolution = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, cumFreqLen - 1)); - if (PyErr_Occurred()) - return 0; + if (PyErr_Occurred()) + return 0; - if(resolution > std::numeric_limits::max()) { - PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); - return 0; - } + if(resolution > std::numeric_limits::max()) { + PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); + return 0; + } - // load cumulative frequency table - SearchType::freq_type cumFreqArr[MAX_NUM_SYMBOLS + 1]; + // load cumulative frequency table + SearchType::freq_type cumFreqArr[MAX_NUM_SYMBOLS + 1]; - for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { - cumFreqArr[i] = static_cast( - PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i))); + for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { + cumFreqArr[i] = static_cast( + PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i))); - if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) - PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and increasing."); + if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) + PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and increasing."); - if (PyErr_Occurred()) - return 0; - } + if (PyErr_Occurred()) + return 0; + } - if (cumFreqArr[0] != 0) { - PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); - return 0; - } + if (cumFreqArr[0] != 0) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); + return 0; + } - // fill up remaining entries - for (Py_ssize_t i = cumFreqLen; i < MAX_NUM_SYMBOLS + 1; ++i) - cumFreqArr[i] = cumFreqArr[cumFreqLen - 1]; + // fill up remaining entries + for (Py_ssize_t i = cumFreqLen; i < MAX_NUM_SYMBOLS + 1; ++i) + cumFreqArr[i] = cumFreqArr[cumFreqLen - 1]; - // decode data - PyObject* data = PyList_New(size); - for (Py_ssize_t i = 0; i < size; ++i) { - rc_type_t::uint index = self->decoder->decode(cumFreqArr[MAX_NUM_SYMBOLS], cumFreqArr); + // decode data + PyObject* data = PyList_New(size); + for (Py_ssize_t i = 0; i < size; ++i) { + rc_type_t::uint index = self->decoder->decode(cumFreqArr[MAX_NUM_SYMBOLS], cumFreqArr); - PyList_SetItem(data, i, PyLong_FromUnsignedLong(index)); + PyList_SetItem(data, i, PyLong_FromUnsignedLong(index)); - if (PyErr_Occurred()) { - Py_DECREF(data); - return 0; - } - } + if (PyErr_Occurred()) { + Py_DECREF(data); + return 0; + } + } - return data; + return data; } const char* RangeEncoder_close_doc = - "close()\n" - "\n" - "Write remaining bits in buffer and close file."; + "close()\n" + "\n" + "Write remaining bits in buffer and close file."; PyObject* RangeEncoder_close(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { - self->encoder->final(); - self->fout->close(); + self->encoder->final(); + self->fout->close(); - Py_INCREF(Py_None); - return Py_None; + Py_INCREF(Py_None); + return Py_None; } const char* RangeDecoder_close_doc = - "close()\n" - "\n" - "Close file."; + "close()\n" + "\n" + "Close file."; PyObject* RangeDecoder_close(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { - self->fin->close(); + self->fin->close(); - Py_INCREF(Py_None); - return Py_None; + Py_INCREF(Py_None); + return Py_None; } diff --git a/python/src/range_coder_interface.h b/python/src/range_coder_interface.h index 85c688a..9957d7e 100644 --- a/python/src/range_coder_interface.h +++ b/python/src/range_coder_interface.h @@ -20,18 +20,18 @@ typedef std::istreambuf_iterator InputIterator; typedef rc_decoder_search_t SearchType; struct RangeEncoderObject { - PyObject_HEAD - rc_encoder_t* encoder; - OutputIterator* iter; - std::ofstream* fout; + PyObject_HEAD + rc_encoder_t* encoder; + OutputIterator* iter; + std::ofstream* fout; }; struct RangeDecoderObject { - PyObject_HEAD - rc_decoder_t* decoder; - InputIterator* begin; - InputIterator* end; - std::ifstream* fin; + PyObject_HEAD + rc_decoder_t* decoder; + InputIterator* begin; + InputIterator* end; + std::ifstream* fin; }; PyObject* RangeEncoder_new(PyTypeObject*, PyObject*, PyObject*); diff --git a/python/tests/test_range_coder.py b/python/tests/test_range_coder.py index 3ee9d34..27e95df 100644 --- a/python/tests/test_range_coder.py +++ b/python/tests/test_range_coder.py @@ -11,219 +11,245 @@ def test_range_coder_overflow(): - """ - Cumulative frequencies must fit in an unsigned integer (assumed to be represented by 32 bits). - This test checks that no error is thrown if the frequencies exceed that limit. - """ + """ + Cumulative frequencies must fit in an unsigned integer (assumed to be represented by 32 bits). + This test checks that no error is thrown if the frequencies exceed that limit. + """ - numBytes = 17 - filepath = mkstemp()[1] + numBytes = 17 + filepath = mkstemp()[1] - # encoding one sequence should require 1 byte - prob = [4, 6, 8] - prob = np.asarray(prob, dtype=np.float64) / np.sum(prob) - cumFreq = prob_to_cum_freq(prob, 128) - cumFreq[-1] = 2**32 + # encoding one sequence should require 1 byte + prob = [4, 6, 8] + prob = np.asarray(prob, dtype=np.float64) / np.sum(prob) + cumFreq = prob_to_cum_freq(prob, 128) + cumFreq[-1] = 2**32 - sequence = [2, 2] - data = sequence * numBytes + sequence = [2, 2] + data = sequence * numBytes - encoder = RangeEncoder(filepath) - with pytest.raises(OverflowError): - encoder.encode(data, cumFreq) - encoder.close() + encoder = RangeEncoder(filepath) + with pytest.raises(OverflowError): + encoder.encode(data, cumFreq) + encoder.close() def test_range_encoder(): - """ - Tests that RangeEncoder writes the expected bits. - - Tests that writing after closing file throws an exception. - """ - - numBytes = 17 - filepath = mkstemp()[1] - - # encoding one sequence should require 1 byte - cumFreq = [0, 4, 6, 8] - sequence = [0, 0, 0, 0, 1, 2] - sequenceByte = b'\x0b' - data = sequence * numBytes - - encoder = RangeEncoder(filepath) - encoder.encode(data, cumFreq) - encoder.close() - - with pytest.raises(RuntimeError): - # file is already closed, should raise an exception - encoder.encode(sequence, cumFreq) - - assert os.stat(filepath).st_size == numBytes - - with open(filepath, 'rb') as handle: - # the first 4 bytes are special - handle.read(4) - - for _ in range(numBytes - 4): - assert handle.read(1) == sequenceByte - - encoder = RangeEncoder(filepath) - with pytest.raises(OverflowError): - # cumFreq contains negative frequencies - encoder.encode(data, [-1, 1]) - with pytest.raises(ValueError): - # cumFreq does not start at zero - encoder.encode(data, [1, 2, 3]) - with pytest.raises(ValueError): - # cumFreq too short - encoder.encode(data, [0, 1]) - with pytest.raises(ValueError): - # symbols with zero probability cannot be encoded - encoder.encode(data, [0, 8, 8, 8]) - with pytest.raises(ValueError): - # invalid frequency table - encoder.encode(data, []) - with pytest.raises(ValueError): - # invalid frequency table - encoder.encode(data, [0]) - encoder.close() - - os.remove(filepath) + """ + Tests that RangeEncoder writes the expected bits. + + Tests that writing after closing file throws an exception. + """ + + numBytes = 17 + filepath = mkstemp()[1] + + # encoding one sequence should require 1 byte + cumFreq = [0, 4, 6, 8] + sequence = [0, 0, 0, 0, 1, 2] + sequenceByte = b'\x0b' + data = sequence * numBytes + + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() + + with pytest.raises(RuntimeError): + # file is already closed, should raise an exception + encoder.encode(sequence, cumFreq) + + assert os.stat(filepath).st_size == numBytes + + with open(filepath, 'rb') as handle: + # the first 4 bytes are special + handle.read(4) + + for _ in range(numBytes - 4): + assert handle.read(1) == sequenceByte + + encoder = RangeEncoder(filepath) + with pytest.raises(OverflowError): + # cumFreq contains negative frequencies + encoder.encode(data, [-1, 1]) + with pytest.raises(ValueError): + # cumFreq does not start at zero + encoder.encode(data, [1, 2, 3]) + with pytest.raises(ValueError): + # cumFreq too short + encoder.encode(data, [0, 1]) + with pytest.raises(ValueError): + # symbols with zero probability cannot be encoded + encoder.encode(data, [0, 8, 8, 8]) + with pytest.raises(ValueError): + # invalid frequency table + encoder.encode(data, []) + with pytest.raises(ValueError): + # invalid frequency table + encoder.encode(data, [0]) + encoder.close() + + os.remove(filepath) def test_range_decoder(): - """ - Tests whether RangeDecoder reproduces symbols encoded by RangeEncoder. - """ + """ + Tests whether RangeDecoder reproduces symbols encoded by RangeEncoder. + """ - random.seed(558) + random.seed(0) - filepath = mkstemp()[1] + filepath = mkstemp()[1] - # encoding one sequence should require 1 byte - cumFreq0 = [0, 4, 6, 8] - cumFreq1 = [0, 2, 5, 7, 10, 14] - data0 = [random.randint(0, len(cumFreq0) - 2) for _ in range(10)] - data1 = [random.randint(0, len(cumFreq1) - 2) for _ in range(17)] + # encoding one sequence should require 1 byte + cumFreq0 = [0, 4, 6, 8] + cumFreq1 = [0, 2, 5, 7, 10, 14] + data0 = [random.randint(0, len(cumFreq0) - 2) for _ in range(10)] + data1 = [random.randint(0, len(cumFreq1) - 2) for _ in range(17)] - encoder = RangeEncoder(filepath) - encoder.encode(data0, cumFreq0) - encoder.encode(data1, cumFreq1) - encoder.close() + encoder = RangeEncoder(filepath) + encoder.encode(data0, cumFreq0) + encoder.encode(data1, cumFreq1) + encoder.close() - decoder = RangeDecoder(filepath) - dataRec0 = decoder.decode(len(data0), cumFreq0) - dataRec1 = decoder.decode(len(data1), cumFreq1) - decoder.close() + decoder = RangeDecoder(filepath) + dataRec0 = decoder.decode(len(data0), cumFreq0) + dataRec1 = decoder.decode(len(data1), cumFreq1) + decoder.close() - # encoded and decoded data should be the same - assert data0 == dataRec0 - assert data1 == dataRec1 + # encoded and decoded data should be the same + assert data0 == dataRec0 + assert data1 == dataRec1 - # make sure reference counting is implemented correctly (call to getrefcount increases it by 1) - assert sys.getrefcount(dataRec0) == 2 - assert sys.getrefcount(dataRec1) == 2 + # make sure reference counting is implemented correctly (call to getrefcount increases it by 1) + assert sys.getrefcount(dataRec0) == 2 + assert sys.getrefcount(dataRec1) == 2 - decoder = RangeDecoder(filepath) - with pytest.raises(ValueError): - # invalid frequency table - decoder.decode(len(data0), []) - with pytest.raises(ValueError): - # invalid frequency table - decoder.decode(len(data0), [0]) - assert decoder.decode(0, cumFreq0) == [] + decoder = RangeDecoder(filepath) + with pytest.raises(ValueError): + # invalid frequency table + decoder.decode(len(data0), []) + with pytest.raises(ValueError): + # invalid frequency table + decoder.decode(len(data0), [0]) + assert decoder.decode(0, cumFreq0) == [] - os.remove(filepath) + os.remove(filepath) + + +def test_range_encoder_decoder(): + """ + Additional tests whether RangeDecoder reproduces symbols encoded by RangeEncoder. + """ + + random.seed(0) + + filepath = mkstemp()[1] + + cumFreq = [0.25, 0.25, 0.5] + + data = [0, 1, 2] + + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() + + decoder = RangeDecoder(filepath) + dataRec = decoder.decode(len(data), cumFreq) + decoder.close() + + assert data == dataRec + + os.remove(filepath) def test_range_decoder_fuzz(): - """ - Test random inputs to the decoder. - """ + """ + Test whether random inputs to the decoder throw any errors. + """ - random.seed(827) - randomState = np.random.RandomState(827) + random.seed(1) + randomState = np.random.RandomState(827) - for _ in range(10): - # generate random frequency table - numSymbols = random.randint(1, 20) - maxFreq = random.randint(2, 100) - cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) - cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int + for _ in range(10): + # generate random frequency table + numSymbols = random.randint(1, 20) + maxFreq = random.randint(2, 100) + cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) + cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int - # decode random symbols - decoder = RangeDecoder('/dev/urandom') - decoder.decode(100, cumFreq) + # decode random symbols + decoder = RangeDecoder('/dev/urandom') + decoder.decode(100, cumFreq) def test_range_encoder_fuzz(): - """ - Test random inputs to the encoder. - """ + """ + Test whether random inputs to the encoder throw any errors. + """ - random.seed(111) - randomState = np.random.RandomState(111) + random.seed(2) + randomState = np.random.RandomState(111) - filepath = mkstemp()[1] + filepath = mkstemp()[1] - for _ in range(10): - # generate random frequency table - numSymbols = random.randint(1, 20) - maxFreq = random.randint(2, 100) - cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) - cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int + for _ in range(10): + # generate random frequency table + numSymbols = random.randint(1, 20) + maxFreq = random.randint(2, 100) + cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) + cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int - # encode random symbols - dataLen = randomState.randint(0, 10) - data = [random.randint(0, numSymbols - 1) for _ in range(dataLen)] - encoder = RangeEncoder(filepath) - encoder.encode(data, cumFreq) - encoder.close() + # encode random symbols + dataLen = randomState.randint(0, 10) + data = [random.randint(0, numSymbols - 1) for _ in range(dataLen)] + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() - os.remove(filepath) + os.remove(filepath) def test_prob_to_cum_freq(): - """ - Tests whether prob_to_cum_freq produces a table with the expected number - of entries, number of samples, and that non-zero probabilities are - represented by non-zero increases in frequency. + """ + Tests whether prob_to_cum_freq produces a table with the expected number + of entries, number of samples, and that non-zero probabilities are + represented by non-zero increases in frequency. - Tests that cum_freq_to_prob is normalized and consistent with prob_to_cum_freq. - """ + Tests that cum_freq_to_prob is normalized and consistent with prob_to_cum_freq. + """ - randomState = np.random.RandomState(190) - resolution = 1024 + randomState = np.random.RandomState(190) + resolution = 1024 - p0 = randomState.dirichlet([.1] * 50) - cumFreq0 = prob_to_cum_freq(p0, resolution) - p1 = cum_freq_to_prob(cumFreq0) - cumFreq1 = prob_to_cum_freq(p1, resolution) + p0 = randomState.dirichlet([.1] * 50) + cumFreq0 = prob_to_cum_freq(p0, resolution) + p1 = cum_freq_to_prob(cumFreq0) + cumFreq1 = prob_to_cum_freq(p1, resolution) - # number of hypothetical samples should correspond to resolution - assert cumFreq0[-1] == resolution - assert len(cumFreq0) == len(p0) + 1 + # number of hypothetical samples should correspond to resolution + assert cumFreq0[-1] == resolution + assert len(cumFreq0) == len(p0) + 1 - # non-zero probabilities should have non-zero frequencies - assert np.all(np.diff(cumFreq0)[p0 > 0.] > 0) + # non-zero probabilities should have non-zero frequencies + assert np.all(np.diff(cumFreq0)[p0 > 0.] > 0) - # probabilities should be normalized. - assert np.isclose(np.sum(p1), 1.) + # probabilities should be normalized. + assert np.isclose(np.sum(p1), 1.) - # while the probabilities might change, frequencies should not - assert cumFreq0 == cumFreq1 + # while the probabilities might change, frequencies should not + assert cumFreq0 == cumFreq1 def test_prob_to_cum_freq_zero_prob(): - """ - Tests whether prob_to_cum_freq handles zero probabilities as expected. - """ + """ + Tests whether prob_to_cum_freq handles zero probabilities as expected. + """ - prob1 = [0.5, 0.25, 0.25] - cumFreq1 = prob_to_cum_freq(prob1, resolution=8) + prob1 = [0.5, 0.25, 0.25] + cumFreq1 = prob_to_cum_freq(prob1, resolution=8) - prob0 = [0.5, 0., 0.25, 0.25, 0., 0.] - cumFreq0 = prob_to_cum_freq(prob0, resolution=8) + prob0 = [0.5, 0., 0.25, 0.25, 0., 0.] + cumFreq0 = prob_to_cum_freq(prob0, resolution=8) - # removing entries corresponding to zeros - assert [cumFreq0[0]] + [cumFreq0[i + 1] for i, p in enumerate(prob0) if p > 0.] == cumFreq1 + # removing entries corresponding to zeros + assert [cumFreq0[0]] + [cumFreq0[i + 1] for i, p in enumerate(prob0) if p > 0.] == cumFreq1 diff --git a/range_coder.hpp b/range_coder.hpp index e97612f..b6c89d7 100644 --- a/range_coder.hpp +++ b/range_coder.hpp @@ -36,173 +36,188 @@ #endif struct rc_type_t { - enum { - TOP = 1U << 24, - TOPMASK = TOP - 1, - }; - typedef unsigned int uint; - typedef unsigned char byte; + enum { + TOP = 1U << 24, + TOPMASK = TOP - 1, + }; + typedef unsigned int uint; + typedef unsigned char byte; }; template class rc_encoder_t : public rc_type_t { -public: - rc_encoder_t(const Iter &i) : iter(i) { - L = 0; - R = 0xFFFFFFFF; - buffer = 0; - carryN = 0; - counter = 0; - start = true; - } - void encode(const uint low, const uint high, const uint total) { - uint r = R / total; - if (high < total) { - R = r * (high-low); - } else { - R -= r * low; - } - uint newL = L + r*low; - if (newL < L) { - // overflow occured (newL >= 2^32) - // buffer FF FF .. FF -> buffer+1 00 00 .. 00 - buffer++; - for (;carryN > 0; carryN--) { - *iter++ = buffer; - buffer = 0; - } - } - L = newL; - while (R < TOP) { - const byte newBuffer = (L >> 24); - if (start) { - buffer = newBuffer; - start = false; - } else if (newBuffer == 0xFF) { - carryN++; - } else { - *iter++ = buffer; - for (; carryN != 0; carryN--) { - *iter++ = 0xFF; - } - buffer = newBuffer; - } - L <<= 8; - R <<= 8; - } - counter++; - } - void final() { - *iter++ = buffer; - for (; carryN != 0; carryN--) { - *iter++ = 0xFF; - } - uint t = L + R; - while (1) { - uint t8 = t >> 24, l8 = L >> 24; - *iter++ = l8; - if (t8 != l8) { - break; - } - t <<= 8; - L <<= 8; - } - } -private: - uint R; - uint L; - bool start; - byte buffer; - uint carryN; - Iter iter; - uint counter; + public: + rc_encoder_t(const Iter &i) : iter(i) { + L = 0; + R = 0xFFFFFFFF; + buffer = 0; + carryN = 0; + counter = 0; + start = true; + } + + void encode(const uint low, const uint high, const uint total) { + // encode symbol by adjusting range + uint r = R / total; + if (high < total) { + R = r * (high - low); + } else { + R -= r * low; + } + uint newL = L + r * low; + + if (newL < L) { + // overflow occured (newL >= 2^32) + // buffer FF FF .. FF -> buffer+1 00 00 .. 00 + buffer++; + for (; carryN > 0; carryN--) { + *iter++ = buffer; + buffer = 0; + } + } + + L = newL; + while (R < TOP) { + const byte newBuffer = (L >> 24); + if (start) { + buffer = newBuffer; + start = false; + } else if (newBuffer == 0xFF) { + carryN++; + } else { + // write left-most byte to file + *iter++ = buffer; + for (; carryN != 0; carryN--) { + *iter++ = 0xFF; + } + buffer = newBuffer; + } + L <<= 8; + R <<= 8; + } + + counter++; + } + + void final() { + *iter++ = buffer; + for (; carryN != 0; carryN--) { + *iter++ = 0xFF; + } + + uint t = L + R; + while (1) { + uint t8 = t >> 24, l8 = L >> 24; + *iter++ = l8; + if (t8 != l8) { + break; + } + t <<= 8; + L <<= 8; + } + } + + private: + uint R; + uint L; + bool start; + byte buffer; + uint carryN; + Iter iter; + uint counter; }; template struct rc_decoder_search_traits_t : public rc_type_t { - typedef FreqType freq_type; - enum { - N = _N, - BASE = _BASE - }; + typedef FreqType freq_type; + enum { + N = _N, + BASE = _BASE + }; }; template struct rc_decoder_search_t : public rc_decoder_search_traits_t { - static rc_type_t::uint get_index(const FreqType *freq, FreqType pos) { - rc_type_t::uint left = 0; - rc_type_t::uint right = _N - 1; - while(left < right) { - rc_type_t::uint mid = (left+right)/2; - if (freq[mid+1] <= pos) left = mid+1; - else right = mid; - } - return left; - } + static rc_type_t::uint get_index(const FreqType *freq, FreqType pos) { + rc_type_t::uint left = 0; + rc_type_t::uint right = _N - 1; + while (left < right) { + rc_type_t::uint mid = (left + right) / 2; + if (freq[mid+1] <= pos) { + left = mid + 1; + } else { + right = mid; + } + } + return left; + } }; #ifdef RANGE_CODER_USE_SSE - template struct rc_decoder_search_t : public rc_decoder_search_traits_t { - static rc_type_t::uint get_index(const short *freq, short pos) { - __m128i v = _mm_set1_epi16(pos); - unsigned i, mask = 0; - for (i = 0; i < 256; i += 16) { - __m128i x = *reinterpret_cast(freq + i); - __m128i y = *reinterpret_cast(freq + i + 8); - __m128i a = _mm_cmplt_epi16(v, x); - __m128i b = _mm_cmplt_epi16(v, y); - mask = (_mm_movemask_epi8(b) << 16) | _mm_movemask_epi8(a); - if (mask) { - return i + (__builtin_ctz(mask) >> 1) - 1; - } - } - return 255; - } + static rc_type_t::uint get_index(const short *freq, short pos) { + __m128i v = _mm_set1_epi16(pos); + unsigned i, mask = 0; + for (i = 0; i < 256; i += 16) { + __m128i x = *reinterpret_cast(freq + i); + __m128i y = *reinterpret_cast(freq + i + 8); + __m128i a = _mm_cmplt_epi16(v, x); + __m128i b = _mm_cmplt_epi16(v, y); + mask = (_mm_movemask_epi8(b) << 16) | _mm_movemask_epi8(a); + if (mask) { + return i + (__builtin_ctz(mask) >> 1) - 1; + } + } + return 255; + } }; - #endif template class rc_decoder_t : public rc_type_t { -public: - typedef SearchType search_type; - typedef typename search_type::freq_type freq_type; - static const unsigned N = search_type::N; - rc_decoder_t(const Iterator& _i, const Iterator _e) : iter(_i), iter_end(_e) { - R = 0xFFFFFFFF; - D = 0; - for (int i = 0; i < 4; i++) { - D = (D << 8) | next(); - } - } - uint decode(const uint total, const freq_type* cumFreq) { - const uint r = R / total; - const int targetPos = std::min(total-1, D / r); - - // find target s.t. cumFreq[target] <= targetPos < cumFreq[target+1] - const uint target = - search_type::get_index(cumFreq, targetPos + search_type::BASE); - const uint low = cumFreq[target] - search_type::BASE; - const uint high = cumFreq[target+1] - search_type::BASE; - - D -= r * low; - if (high != total) { - R = r * (high-low); - } else { - R -= r * low; - } - - while (R < TOP) { - R <<= 8; - D = (D << 8) | next(); - } - - return target; - } - byte next() { - return iter != iter_end ? (byte)*iter++ : 0xff; - } -private: - uint R; - uint D; - Iterator iter, iter_end; + public: + typedef SearchType search_type; + typedef typename search_type::freq_type freq_type; + static const unsigned N = search_type::N; + + rc_decoder_t(const Iterator& _i, const Iterator _e) : iter(_i), iter_end(_e) { + R = 0xFFFFFFFF; + D = 0; + + // read first four bytes from file + for (int i = 0; i < 4; i++) { + D = (D << 8) | next(); + } + } + + uint decode(const uint total, const freq_type* cumFreq) { + const uint r = R / total; + const int targetPos = std::min(total - 1, D / r); + + // find target s.t. cumFreq[target] <= targetPos < cumFreq[target + 1] + const uint target = search_type::get_index(cumFreq, targetPos + search_type::BASE); + const uint low = cumFreq[target] - search_type::BASE; + const uint high = cumFreq[target + 1] - search_type::BASE; + + D -= r * low; + if (high != total) { + R = r * (high - low); + } else { + R -= r * low; + } + + while (R < TOP) { + R <<= 8; + D = (D << 8) | next(); + } + + return target; + } + + byte next() { + return iter != iter_end ? (byte)*iter++ : 0xff; + } + + private: + uint R; + uint D; + Iterator iter, iter_end; }; #endif diff --git a/setup.py b/setup.py index a6cfe24..488fd98 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,18 @@ from setuptools import setup, Extension setup( - name="range_coder", - version="1.0", - description="A fast implementation of a range coder", - packages=["range_coder"], - package_dir={"range_coder": "python"}, - license="BSD", - ext_modules=[ - Extension("range_coder._range_coder", - language="c++", - include_dirs=["."], - sources=[ - "python/src/range_coder_interface.cpp", - "python/src/module.cpp" - ])]) + name="range_coder", + version="1.0", + description="A fast implementation of a range coder", + packages=["range_coder"], + package_dir={"range_coder": "python"}, + license="BSD", + ext_modules=[ + Extension( + "range_coder._range_coder", + language="c++", + include_dirs=["."], + sources=[ + "python/src/range_coder_interface.cpp", + "python/src/module.cpp" + ])]) From b3da9d4db41de5faa23fe193e0927ad27178e945 Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Sun, 9 Feb 2020 23:37:12 +0100 Subject: [PATCH 5/8] range_coder uses std::min from --- range_coder.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/range_coder.hpp b/range_coder.hpp index b6c89d7..0d9a574 100644 --- a/range_coder.hpp +++ b/range_coder.hpp @@ -31,6 +31,8 @@ #ifndef __RANGE_CODER_HPP__ #define __RANGE_CODER_HPP__ +#include + #ifdef RANGE_CODER_USE_SSE #include #endif From 96d71efc14c14f8f037fd0d6589e82c23685f17f Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Sun, 9 Feb 2020 23:52:02 +0100 Subject: [PATCH 6/8] fixed a bug causing incorrect reconstructions of very short sequences --- python/tests/test_range_coder.py | 23 ++++++++++++----------- range_coder.hpp | 8 +++++--- setup.py | 2 +- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/python/tests/test_range_coder.py b/python/tests/test_range_coder.py index 27e95df..51e5597 100644 --- a/python/tests/test_range_coder.py +++ b/python/tests/test_range_coder.py @@ -138,26 +138,27 @@ def test_range_decoder(): def test_range_encoder_decoder(): """ - Additional tests whether RangeDecoder reproduces symbols encoded by RangeEncoder. + Additional tests to check whether RangeDecoder reproduces symbols encoded by RangeEncoder. """ random.seed(0) filepath = mkstemp()[1] - cumFreq = [0.25, 0.25, 0.5] + for _ in range(20): + numSymbols = np.random.randint(1, 6) + cumFreq = [0] + np.cumsum(np.random.randint(1, 10, size=numSymbols)).tolist() + data = np.random.randint(numSymbols, size=np.random.randint(20)).tolist() - data = [0, 1, 2] - - encoder = RangeEncoder(filepath) - encoder.encode(data, cumFreq) - encoder.close() + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() - decoder = RangeDecoder(filepath) - dataRec = decoder.decode(len(data), cumFreq) - decoder.close() + decoder = RangeDecoder(filepath) + dataRec = decoder.decode(len(data), cumFreq) + decoder.close() - assert data == dataRec + assert data == dataRec os.remove(filepath) diff --git a/range_coder.hpp b/range_coder.hpp index 0d9a574..c72550f 100644 --- a/range_coder.hpp +++ b/range_coder.hpp @@ -101,9 +101,11 @@ template class rc_encoder_t : public rc_type_t { } void final() { - *iter++ = buffer; - for (; carryN != 0; carryN--) { - *iter++ = 0xFF; + if (!start) { + *iter++ = buffer; + for (; carryN != 0; carryN--) { + *iter++ = 0xFF; + } } uint t = L + R; diff --git a/setup.py b/setup.py index 488fd98..60553f9 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="range_coder", - version="1.0", + version="1.1", description="A fast implementation of a range coder", packages=["range_coder"], package_dir={"range_coder": "python"}, From 1d26c518d28b16fe1b91372a6973d2f5578f4a0a Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Thu, 7 Sep 2023 11:25:41 +0100 Subject: [PATCH 7/8] added LICENSE file --- LICENSE | 15 +++++++++++++++ setup.py | 9 ++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..83a068d --- /dev/null +++ b/LICENSE @@ -0,0 +1,15 @@ +Copyright (c) 2006, Daisuke Okanohara +Copyright (c) 2008-2010, Cybozu Labs, Inc. +Copyright (c) 2017-2023, Lucas Theis +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/setup.py b/setup.py index 60553f9..590bda8 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="range_coder", - version="1.1", + version="1.1.1", description="A fast implementation of a range coder", packages=["range_coder"], package_dir={"range_coder": "python"}, @@ -14,5 +14,8 @@ include_dirs=["."], sources=[ "python/src/range_coder_interface.cpp", - "python/src/module.cpp" - ])]) + "python/src/module.cpp", + ] + ) + ] +) From deccf77fca40417a3bfd4e7dc0d73f60c99c140b Mon Sep 17 00:00:00 2001 From: Lucas Theis Date: Tue, 17 Oct 2023 18:31:18 +0100 Subject: [PATCH 8/8] fix bug in Python3 module definition --- python/src/module.cpp | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/src/module.cpp b/python/src/module.cpp index 17f4e21..af61ba2 100644 --- a/python/src/module.cpp +++ b/python/src/module.cpp @@ -123,7 +123,7 @@ PyTypeObject RangeDecoder_type = { static PyModuleDef range_coder_module = { PyModuleDef_HEAD_INIT, "_range_coder", - "A fast implementation of a range encoder and decoder." + "A fast implementation of a range encoder and decoder.", -1, 0, 0, 0, 0, 0 }; #endif diff --git a/setup.py b/setup.py index 590bda8..c44f2f7 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="range_coder", - version="1.1.1", + version="1.1.2", description="A fast implementation of a range coder", packages=["range_coder"], package_dir={"range_coder": "python"},