diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..83a068d --- /dev/null +++ b/LICENSE @@ -0,0 +1,15 @@ +Copyright (c) 2006, Daisuke Okanohara +Copyright (c) 2008-2010, Cybozu Labs, Inc. +Copyright (c) 2017-2023, Lucas Theis +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..56da9e7 --- /dev/null +++ b/README.md @@ -0,0 +1,28 @@ +Installation +============ + + pip install range-coder + + +Example +======= + +```python +from range_coder import RangeEncoder, RangeDecoder, prob_to_cum_freq + +data = [2, 0, 1, 0, 0, 0, 1, 2, 2] +prob = [0.5, 0.2, 0.3] + +# convert probabilities to cumulative integer frequency table +cumFreq = prob_to_cum_freq(prob, resolution=128) + +# encode data +encoder = RangeEncoder(filepath) +encoder.encode(data, cumFreq) +encoder.close() + +# decode data +decoder = RangeDecoder(filepath) +dataRec = decoder.decode(len(data), cumFreq) +decoder.close() +``` diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 0000000..0018533 --- /dev/null +++ b/python/__init__.py @@ -0,0 +1,59 @@ +from warnings import warn +from range_coder._range_coder import RangeEncoder, RangeDecoder # noqa: F401 + +try: + import numpy as np +except ImportError: + pass + + +def prob_to_cum_freq(prob, resolution=1024): + """ + Converts probability distribution into a cumulative frequency table. + + Makes sure that non-zero probabilities are represented by non-zero frequencies, + provided that :samp:`len({prob}) <= {resolution}`. + + Parameters + ---------- + prob : ndarray or list + A one-dimensional array representing a probability distribution + + resolution : int + Number of hypothetical samples used to generate integer frequencies + + Returns + ------- + list + Cumulative frequency table + """ + + if len(prob) > resolution: + warn('Resolution smaller than number of symbols.') + + prob = np.asarray(prob, dtype=np.float64) + freq = np.zeros(prob.size, dtype=int) + + # this is similar to gradient descent in KL divergence (convex) + with np.errstate(divide='ignore', invalid='ignore'): + for _ in range(resolution): + freq[np.nanargmax(prob / freq)] += 1 + + return [0] + np.cumsum(freq).tolist() + + +def cum_freq_to_prob(cumFreq): + """ + Converts a cumulative frequency table into a probability distribution. + + Parameters + ---------- + cumFreq : list + Cumulative frequency table + + Returns + ------- + ndarray + Probability distribution + """ + return np.diff(cumFreq).astype(np.float64) / cumFreq[-1] diff --git a/python/src/module.cpp b/python/src/module.cpp new file mode 100644 index 0000000..af61ba2 --- /dev/null +++ b/python/src/module.cpp @@ -0,0 +1,162 @@ +#include +#include "range_coder_interface.h" + +static PyMethodDef RangeEncoder_methods[] = { + {"encode", + (PyCFunction)RangeEncoder_encode, + METH_VARARGS | METH_KEYWORDS, + RangeEncoder_encode_doc}, + {"close", + (PyCFunction)RangeEncoder_close, + METH_VARARGS | METH_KEYWORDS, + RangeEncoder_close_doc}, + {0} +}; + + +static PyGetSetDef RangeEncoder_getset[] = { + {0} +}; + + +PyTypeObject RangeEncoder_type = { + PyVarObject_HEAD_INIT(0, 0) + "range_coder.RangeEncoder", /*tp_name*/ + sizeof(RangeEncoderObject), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)RangeEncoder_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT, /*tp_flags*/ + RangeEncoder_doc, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + RangeEncoder_methods, /*tp_methods*/ + 0, /*tp_members*/ + RangeEncoder_getset, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + (initproc)RangeEncoder_init, /*tp_init*/ + 0, /*tp_alloc*/ + RangeEncoder_new, /*tp_new*/ +}; + +static PyMethodDef RangeDecoder_methods[] = { + {"decode", + (PyCFunction)RangeDecoder_decode, + METH_VARARGS | METH_KEYWORDS, + RangeDecoder_decode_doc}, + {"close", + (PyCFunction)RangeDecoder_close, + METH_VARARGS | METH_KEYWORDS, + RangeDecoder_close_doc}, + {0} +}; + + +static PyGetSetDef RangeDecoder_getset[] = { + {0} +}; + + +PyTypeObject RangeDecoder_type = { + PyVarObject_HEAD_INIT(0, 0) + "range_coder.RangeDecoder", /*tp_name*/ + sizeof(RangeDecoderObject), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)RangeDecoder_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequdece*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT, /*tp_flags*/ + RangeDecoder_doc, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + RangeDecoder_methods, /*tp_methods*/ + 0, /*tp_members*/ + RangeDecoder_getset, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + 0, /*tp_dictoffset*/ + (initproc)RangeDecoder_init, /*tp_init*/ + 0, /*tp_alloc*/ + RangeDecoder_new, /*tp_new*/ +}; + +#if PY_MAJOR_VERSION >= 3 +static PyModuleDef range_coder_module = { + PyModuleDef_HEAD_INIT, + "_range_coder", + "A fast implementation of a range encoder and decoder.", + -1, 0, 0, 0, 0, 0 +}; +#endif + + +#if PY_MAJOR_VERSION >= 3 +PyMODINIT_FUNC PyInit__range_coder() { + // create module object + PyObject* module = PyModule_Create(&range_coder_module); +#define RETVAL 0; +#else +PyMODINIT_FUNC init_range_coder() { + PyObject* module = Py_InitModule3( + "_range_coder", 0, "A fast implementation of a range encoder and decoder."); +#define RETVAL void(); +#endif + + if(!module) + return RETVAL; + + // initialize types + if(PyType_Ready(&RangeEncoder_type) < 0) + return RETVAL; + if(PyType_Ready(&RangeDecoder_type) < 0) + return RETVAL; + + // add types to module + Py_INCREF(&RangeEncoder_type); + PyModule_AddObject(module, "RangeEncoder", reinterpret_cast(&RangeEncoder_type)); + Py_INCREF(&RangeDecoder_type); + PyModule_AddObject(module, "RangeDecoder", reinterpret_cast(&RangeDecoder_type)); + +#if PY_MAJOR_VERSION >= 3 + return module; +#endif +} diff --git a/python/src/range_coder_interface.cpp b/python/src/range_coder_interface.cpp new file mode 100644 index 0000000..29ec626 --- /dev/null +++ b/python/src/range_coder_interface.cpp @@ -0,0 +1,345 @@ +#include "range_coder_interface.h" +#include + +PyObject* RangeEncoder_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + + if (self) + reinterpret_cast(self)->encoder = 0; + + return self; +} + + +PyObject* RangeDecoder_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { + PyObject* self = type->tp_alloc(type, 0); + + if (self) + reinterpret_cast(self)->decoder = 0; + + return self; +} + + +const char* RangeEncoder_doc = "A fast implementation of a range encoder."; + +int RangeEncoder_init(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"filepath", 0}; + const char* filepath = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) + return -1; + + self->fout = new std::ofstream(filepath, std::ios::out | std::ios::binary); + self->iter = new OutputIterator(*(self->fout)); + self->encoder = new rc_encoder_t(*(self->iter)); + + return 0; +} + + +const char* RangeDecoder_doc = "A fast implementation of a range decoder."; + +int RangeDecoder_init(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"filepath", 0}; + const char* filepath = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", const_cast(kwlist), &filepath)) + return -1; + + self->fin = new std::ifstream(filepath, std::ios::in | std::ios::binary); + self->begin = new InputIterator(*(self->fin)); + self->end = new InputIterator(); + self->decoder = new rc_decoder_t(*(self->begin), *(self->end)); + + return 0; +} + + +void RangeEncoder_dealloc(RangeEncoderObject* self) { + if (self->encoder) { + // flush buffer + self->encoder->final(); + + delete self->encoder; + delete self->iter; + delete self->fout; + } + + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + + +void RangeDecoder_dealloc(RangeDecoderObject* self) { + if (self->decoder) { + delete self->decoder; + delete self->begin; + delete self->end; + delete self->fin; + } + + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + + +const char* RangeEncoder_encode_doc = + "encode(data, cumFreq)\n" + "\n" + "Encodes a list of indices using the given cumulative frequency table.\n" + "\n" + "The length of the frequency table should be the number of possible symbols plus one.\n" + "\n" + "Parameters\n" + "----------\n" + "data : list[int]\n" + " A list of positive integers representing indices into cumulative frequency table\n" + "\n" + "cumFreq : list[int]\n" + " List of increasing positive integers representing cumulative frequencies"; + +PyObject* RangeEncoder_encode(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"data", "cumFreq", 0}; + + PyObject* data; + PyObject* cumFreq; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast(kwlist), &data, &cumFreq)) + return 0; + + if (!self->fout->is_open()) { + PyErr_SetString(PyExc_RuntimeError, "File closed."); + return 0; + } + + if (!PyList_Check(data)) { + PyErr_SetString(PyExc_TypeError, "`data` needs to be a list."); + return 0; + } + + if (!PyList_Check(cumFreq)) { + PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list."); + return 0; + } + + // load cumulative frequency table + Py_ssize_t cumFreqLen = PyList_Size(cumFreq); + + if (cumFreqLen < 2) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); + return 0; + } + + unsigned long* cumFreqArr = new unsigned long[cumFreqLen]; + + for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { + cumFreqArr[i] = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i)); + + if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) + PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and non-decreasing."); + + if (PyErr_Occurred()) { + delete[] cumFreqArr; + return 0; + } + } + + if (cumFreqArr[0] != 0) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); + delete[] cumFreqArr; + return 0; + } + + if(cumFreqArr[cumFreqLen - 1] > std::numeric_limits::max()) { + PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); + return 0; + } + + // load data + Py_ssize_t dataLen = PyList_Size(data); + Py_ssize_t* dataArr = new Py_ssize_t[dataLen]; + + for (Py_ssize_t i = 0; i < dataLen; ++i) { + #if PY_MAJOR_VERSION >= 3 + dataArr[i] = PyLong_AsSsize_t(PyList_GetItem(data, i)); + #else + dataArr[i] = PyInt_AsSsize_t(PyList_GetItem(data, i)); + #endif + + if (!PyErr_Occurred() and dataArr[i] > cumFreqLen - 2) + PyErr_SetString(PyExc_ValueError, "An entry in `data` is too large or `cumFreq` is too short."); + + if (PyErr_Occurred()) + { + delete[] cumFreqArr; + delete[] dataArr; + return 0; + } + } + + bool positive = true; + for (int i = 0; i < cumFreqLen - 1; ++i) { + if (cumFreqArr[i] == cumFreqArr[i + 1]) { + positive = false; + break; + } + } + + // encode data + if (positive) { + // no extra checks necessary + for (int i = 0; i < dataLen; ++i) { + self->encoder->encode( + cumFreqArr[dataArr[i]], + cumFreqArr[dataArr[i] + 1], + cumFreqArr[cumFreqLen - 1]); + } + } else { + for (int i = 0; i < dataLen; ++i) { + unsigned long start = cumFreqArr[dataArr[i]]; + unsigned long end = cumFreqArr[dataArr[i] + 1]; + + if (start == end) { + PyErr_SetString(PyExc_ValueError, "Cannot encode symbol due to zero frequency."); + delete[] cumFreqArr; + delete[] dataArr; + return 0; + } + + self->encoder->encode(start, end, cumFreqArr[cumFreqLen - 1]); + } + } + + delete[] cumFreqArr; + delete[] dataArr; + + Py_INCREF(Py_None); + return Py_None; +} + + +const char* RangeDecoder_decode_doc = + "decode(size, cumFreq)\n" + "\n" + "Decodes a list of indices using the given cumulative frequency table.\n" + "\n" + "The length of the frequency table should be the number of possible symbols plus one.\n" + "\n" + "Parameters\n" + "----------\n" + "size : int\n" + " Number of symbols to decode\n" + "\n" + "cumFreq : list[int]\n" + " List of increasing positive integers representing cumulative frequencies\n" + "\n" + "Returns\n" + "-------\n" + "list[int]\n" + " List of decoded indices"; + +PyObject* RangeDecoder_decode(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { + const char* kwlist[] = {"size", "cumFreq", 0}; + + Py_ssize_t size; + PyObject* cumFreq; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "nO", const_cast(kwlist), &size, &cumFreq)) + return 0; + + if (!self->fin->is_open()) { + PyErr_SetString(PyExc_RuntimeError, "File closed."); + return 0; + } + + if (!PyList_Check(cumFreq)) { + PyErr_SetString(PyExc_TypeError, "`cumFreq` needs to be a list of frequencies."); + return 0; + } + + Py_ssize_t cumFreqLen = PyList_Size(cumFreq); + + if (cumFreqLen < 2) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should have at least 2 entries (1 symbol)."); + return 0; + } + + if (cumFreqLen > MAX_NUM_SYMBOLS + 1) { + PyErr_SetString(PyExc_ValueError, "Number of symbols can be at most MAX_NUM_SYMBOLS."); + return 0; + } + + unsigned long resolution = PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, cumFreqLen - 1)); + + if (PyErr_Occurred()) + return 0; + + if(resolution > std::numeric_limits::max()) { + PyErr_SetString(PyExc_OverflowError, "Maximal allowable resolution of frequency table exceeded."); + return 0; + } + + // load cumulative frequency table + SearchType::freq_type cumFreqArr[MAX_NUM_SYMBOLS + 1]; + + for (Py_ssize_t i = 0; i < cumFreqLen; ++i) { + cumFreqArr[i] = static_cast( + PyLong_AsUnsignedLong(PyList_GetItem(cumFreq, i))); + + if (!PyErr_Occurred() and i > 0 and cumFreqArr[i - 1] > cumFreqArr[i]) + PyErr_SetString(PyExc_ValueError, "Entries in `cumFreq` have to be non-negative and increasing."); + + if (PyErr_Occurred()) + return 0; + } + + if (cumFreqArr[0] != 0) { + PyErr_SetString(PyExc_ValueError, "`cumFreq` should start with 0."); + return 0; + } + + // fill up remaining entries + for (Py_ssize_t i = cumFreqLen; i < MAX_NUM_SYMBOLS + 1; ++i) + cumFreqArr[i] = cumFreqArr[cumFreqLen - 1]; + + // decode data + PyObject* data = PyList_New(size); + for (Py_ssize_t i = 0; i < size; ++i) { + rc_type_t::uint index = self->decoder->decode(cumFreqArr[MAX_NUM_SYMBOLS], cumFreqArr); + + PyList_SetItem(data, i, PyLong_FromUnsignedLong(index)); + + if (PyErr_Occurred()) { + Py_DECREF(data); + return 0; + } + } + + return data; +} + + +const char* RangeEncoder_close_doc = + "close()\n" + "\n" + "Write remaining bits in buffer and close file."; + +PyObject* RangeEncoder_close(RangeEncoderObject* self, PyObject* args, PyObject* kwds) { + self->encoder->final(); + self->fout->close(); + + Py_INCREF(Py_None); + return Py_None; +} + + +const char* RangeDecoder_close_doc = + "close()\n" + "\n" + "Close file."; + +PyObject* RangeDecoder_close(RangeDecoderObject* self, PyObject* args, PyObject* kwds) { + self->fin->close(); + + Py_INCREF(Py_None); + return Py_None; +} diff --git a/python/src/range_coder_interface.h b/python/src/range_coder_interface.h new file mode 100644 index 0000000..9957d7e --- /dev/null +++ b/python/src/range_coder_interface.h @@ -0,0 +1,52 @@ +#ifndef RANGE_CODER_INTERFACE_H +#define RANGE_CODER_INTERFACE_H + +#include +#include +#include +#include "range_coder.hpp" + +#define MAX_NUM_SYMBOLS 1024 + +extern const char* RangeEncoder_doc; +extern const char* RangeEncoder_encode_doc; +extern const char* RangeEncoder_close_doc; +extern const char* RangeDecoder_doc; +extern const char* RangeDecoder_decode_doc; +extern const char* RangeDecoder_close_doc; + +typedef std::ostream_iterator OutputIterator; +typedef std::istreambuf_iterator InputIterator; +typedef rc_decoder_search_t SearchType; + +struct RangeEncoderObject { + PyObject_HEAD + rc_encoder_t* encoder; + OutputIterator* iter; + std::ofstream* fout; +}; + +struct RangeDecoderObject { + PyObject_HEAD + rc_decoder_t* decoder; + InputIterator* begin; + InputIterator* end; + std::ifstream* fin; +}; + +PyObject* RangeEncoder_new(PyTypeObject*, PyObject*, PyObject*); +PyObject* RangeDecoder_new(PyTypeObject*, PyObject*, PyObject*); + +int RangeEncoder_init(RangeEncoderObject*, PyObject*, PyObject*); +int RangeDecoder_init(RangeDecoderObject*, PyObject*, PyObject*); + +void RangeEncoder_dealloc(RangeEncoderObject*); +void RangeDecoder_dealloc(RangeDecoderObject*); + +PyObject* RangeEncoder_encode(RangeEncoderObject*, PyObject*, PyObject*); +PyObject* RangeDecoder_decode(RangeDecoderObject*, PyObject*, PyObject*); + +PyObject* RangeEncoder_close(RangeEncoderObject*, PyObject*, PyObject*); +PyObject* RangeDecoder_close(RangeDecoderObject*, PyObject*, PyObject*); + +#endif diff --git a/python/tests/test_range_coder.py b/python/tests/test_range_coder.py new file mode 100644 index 0000000..51e5597 --- /dev/null +++ b/python/tests/test_range_coder.py @@ -0,0 +1,256 @@ +import os +import random +import sys +from tempfile import mkstemp + +import numpy as np +import pytest + +from range_coder import RangeEncoder, RangeDecoder +from range_coder import prob_to_cum_freq, cum_freq_to_prob + + +def test_range_coder_overflow(): + """ + Cumulative frequencies must fit in an unsigned integer (assumed to be represented by 32 bits). + This test checks that no error is thrown if the frequencies exceed that limit. + """ + + numBytes = 17 + filepath = mkstemp()[1] + + # encoding one sequence should require 1 byte + prob = [4, 6, 8] + prob = np.asarray(prob, dtype=np.float64) / np.sum(prob) + cumFreq = prob_to_cum_freq(prob, 128) + cumFreq[-1] = 2**32 + + sequence = [2, 2] + data = sequence * numBytes + + encoder = RangeEncoder(filepath) + with pytest.raises(OverflowError): + encoder.encode(data, cumFreq) + encoder.close() + + +def test_range_encoder(): + """ + Tests that RangeEncoder writes the expected bits. + + Tests that writing after closing file throws an exception. + """ + + numBytes = 17 + filepath = mkstemp()[1] + + # encoding one sequence should require 1 byte + cumFreq = [0, 4, 6, 8] + sequence = [0, 0, 0, 0, 1, 2] + sequenceByte = b'\x0b' + data = sequence * numBytes + + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() + + with pytest.raises(RuntimeError): + # file is already closed, should raise an exception + encoder.encode(sequence, cumFreq) + + assert os.stat(filepath).st_size == numBytes + + with open(filepath, 'rb') as handle: + # the first 4 bytes are special + handle.read(4) + + for _ in range(numBytes - 4): + assert handle.read(1) == sequenceByte + + encoder = RangeEncoder(filepath) + with pytest.raises(OverflowError): + # cumFreq contains negative frequencies + encoder.encode(data, [-1, 1]) + with pytest.raises(ValueError): + # cumFreq does not start at zero + encoder.encode(data, [1, 2, 3]) + with pytest.raises(ValueError): + # cumFreq too short + encoder.encode(data, [0, 1]) + with pytest.raises(ValueError): + # symbols with zero probability cannot be encoded + encoder.encode(data, [0, 8, 8, 8]) + with pytest.raises(ValueError): + # invalid frequency table + encoder.encode(data, []) + with pytest.raises(ValueError): + # invalid frequency table + encoder.encode(data, [0]) + encoder.close() + + os.remove(filepath) + + +def test_range_decoder(): + """ + Tests whether RangeDecoder reproduces symbols encoded by RangeEncoder. + """ + + random.seed(0) + + filepath = mkstemp()[1] + + # encoding one sequence should require 1 byte + cumFreq0 = [0, 4, 6, 8] + cumFreq1 = [0, 2, 5, 7, 10, 14] + data0 = [random.randint(0, len(cumFreq0) - 2) for _ in range(10)] + data1 = [random.randint(0, len(cumFreq1) - 2) for _ in range(17)] + + encoder = RangeEncoder(filepath) + encoder.encode(data0, cumFreq0) + encoder.encode(data1, cumFreq1) + encoder.close() + + decoder = RangeDecoder(filepath) + dataRec0 = decoder.decode(len(data0), cumFreq0) + dataRec1 = decoder.decode(len(data1), cumFreq1) + decoder.close() + + # encoded and decoded data should be the same + assert data0 == dataRec0 + assert data1 == dataRec1 + + # make sure reference counting is implemented correctly (call to getrefcount increases it by 1) + assert sys.getrefcount(dataRec0) == 2 + assert sys.getrefcount(dataRec1) == 2 + + decoder = RangeDecoder(filepath) + with pytest.raises(ValueError): + # invalid frequency table + decoder.decode(len(data0), []) + with pytest.raises(ValueError): + # invalid frequency table + decoder.decode(len(data0), [0]) + assert decoder.decode(0, cumFreq0) == [] + + os.remove(filepath) + + +def test_range_encoder_decoder(): + """ + Additional tests to check whether RangeDecoder reproduces symbols encoded by RangeEncoder. + """ + + random.seed(0) + + filepath = mkstemp()[1] + + for _ in range(20): + numSymbols = np.random.randint(1, 6) + cumFreq = [0] + np.cumsum(np.random.randint(1, 10, size=numSymbols)).tolist() + data = np.random.randint(numSymbols, size=np.random.randint(20)).tolist() + + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() + + decoder = RangeDecoder(filepath) + dataRec = decoder.decode(len(data), cumFreq) + decoder.close() + + assert data == dataRec + + os.remove(filepath) + + +def test_range_decoder_fuzz(): + """ + Test whether random inputs to the decoder throw any errors. + """ + + random.seed(1) + randomState = np.random.RandomState(827) + + for _ in range(10): + # generate random frequency table + numSymbols = random.randint(1, 20) + maxFreq = random.randint(2, 100) + cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) + cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int + + # decode random symbols + decoder = RangeDecoder('/dev/urandom') + decoder.decode(100, cumFreq) + + +def test_range_encoder_fuzz(): + """ + Test whether random inputs to the encoder throw any errors. + """ + + random.seed(2) + randomState = np.random.RandomState(111) + + filepath = mkstemp()[1] + + for _ in range(10): + # generate random frequency table + numSymbols = random.randint(1, 20) + maxFreq = random.randint(2, 100) + cumFreq = np.cumsum(randomState.randint(1, maxFreq, size=numSymbols)) + cumFreq = [0] + [int(i) for i in cumFreq] # convert numpy.int64 to int + + # encode random symbols + dataLen = randomState.randint(0, 10) + data = [random.randint(0, numSymbols - 1) for _ in range(dataLen)] + encoder = RangeEncoder(filepath) + encoder.encode(data, cumFreq) + encoder.close() + + os.remove(filepath) + + +def test_prob_to_cum_freq(): + """ + Tests whether prob_to_cum_freq produces a table with the expected number + of entries, number of samples, and that non-zero probabilities are + represented by non-zero increases in frequency. + + Tests that cum_freq_to_prob is normalized and consistent with prob_to_cum_freq. + """ + + randomState = np.random.RandomState(190) + resolution = 1024 + + p0 = randomState.dirichlet([.1] * 50) + cumFreq0 = prob_to_cum_freq(p0, resolution) + p1 = cum_freq_to_prob(cumFreq0) + cumFreq1 = prob_to_cum_freq(p1, resolution) + + # number of hypothetical samples should correspond to resolution + assert cumFreq0[-1] == resolution + assert len(cumFreq0) == len(p0) + 1 + + # non-zero probabilities should have non-zero frequencies + assert np.all(np.diff(cumFreq0)[p0 > 0.] > 0) + + # probabilities should be normalized. + assert np.isclose(np.sum(p1), 1.) + + # while the probabilities might change, frequencies should not + assert cumFreq0 == cumFreq1 + + +def test_prob_to_cum_freq_zero_prob(): + """ + Tests whether prob_to_cum_freq handles zero probabilities as expected. + """ + + prob1 = [0.5, 0.25, 0.25] + cumFreq1 = prob_to_cum_freq(prob1, resolution=8) + + prob0 = [0.5, 0., 0.25, 0.25, 0., 0.] + cumFreq0 = prob_to_cum_freq(prob0, resolution=8) + + # removing entries corresponding to zeros + assert [cumFreq0[0]] + [cumFreq0[i + 1] for i, p in enumerate(prob0) if p > 0.] == cumFreq1 diff --git a/range_coder.hpp b/range_coder.hpp index a617e91..c72550f 100644 --- a/range_coder.hpp +++ b/range_coder.hpp @@ -1,11 +1,11 @@ /* - * Copyrgght (c) 2006, Daisuke Okanohara + * Copyright (c) 2006, Daisuke Okanohara * Copyright (c) 2008-2010, Cybozu Labs, Inc. * All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: - * + * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright notice, @@ -14,7 +14,7 @@ * * Neither the name of the copyright holders nor the names of its * contributors may be used to endorse or promote products derived from this * software without specific prior written permission. - * + * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE @@ -31,178 +31,197 @@ #ifndef __RANGE_CODER_HPP__ #define __RANGE_CODER_HPP__ +#include + #ifdef RANGE_CODER_USE_SSE #include #endif struct rc_type_t { - enum { - TOP = 1U << 24, - TOPMASK = TOP - 1, - }; - typedef unsigned int uint; - typedef unsigned char byte; + enum { + TOP = 1U << 24, + TOPMASK = TOP - 1, + }; + typedef unsigned int uint; + typedef unsigned char byte; }; template class rc_encoder_t : public rc_type_t { -public: - rc_encoder_t(const Iter &i) : iter(i) { - L = 0; - R = 0xFFFFFFFF; - buffer = 0; - carryN = 0; - counter = 0; - start = true; - } - void encode(const uint low, const uint high, const uint total) { - uint r = R / total; - if (high < total) { - R = r * (high-low); - } else { - R -= r * low; - } - uint newL = L + r*low; - if (newL < L) { - //overflow occured (newL >= 2^32) - //buffer FF FF .. FF -> buffer+1 00 00 .. 00 - buffer++; - for (;carryN > 0; carryN--) { - *iter++ = buffer; - buffer = 0; - } - } - L = newL; - while (R < TOP) { - const byte newBuffer = (L >> 24); - if (start) { - buffer = newBuffer; - start = false; - } else if (newBuffer == 0xFF) { - carryN++; - } else { - *iter++ = buffer; - for (; carryN != 0; carryN--) { - *iter++ = 0xFF; - } - buffer = newBuffer; - } - L <<= 8; - R <<= 8; - } - counter++; - } - void final() { - *iter++ = buffer; - for (; carryN != 0; carryN--) { - *iter++ = 0xFF; - } - uint t = L + R; - while (1) { - uint t8 = t >> 24, l8 = L >> 24; - *iter++ = l8; - if (t8 != l8) { - break; - } - t <<= 8; - L <<= 8; - } - } -private: - uint R; - uint L; - bool start; - byte buffer; - uint carryN; - Iter iter; - uint counter; + public: + rc_encoder_t(const Iter &i) : iter(i) { + L = 0; + R = 0xFFFFFFFF; + buffer = 0; + carryN = 0; + counter = 0; + start = true; + } + + void encode(const uint low, const uint high, const uint total) { + // encode symbol by adjusting range + uint r = R / total; + if (high < total) { + R = r * (high - low); + } else { + R -= r * low; + } + uint newL = L + r * low; + + if (newL < L) { + // overflow occured (newL >= 2^32) + // buffer FF FF .. FF -> buffer+1 00 00 .. 00 + buffer++; + for (; carryN > 0; carryN--) { + *iter++ = buffer; + buffer = 0; + } + } + + L = newL; + while (R < TOP) { + const byte newBuffer = (L >> 24); + if (start) { + buffer = newBuffer; + start = false; + } else if (newBuffer == 0xFF) { + carryN++; + } else { + // write left-most byte to file + *iter++ = buffer; + for (; carryN != 0; carryN--) { + *iter++ = 0xFF; + } + buffer = newBuffer; + } + L <<= 8; + R <<= 8; + } + + counter++; + } + + void final() { + if (!start) { + *iter++ = buffer; + for (; carryN != 0; carryN--) { + *iter++ = 0xFF; + } + } + + uint t = L + R; + while (1) { + uint t8 = t >> 24, l8 = L >> 24; + *iter++ = l8; + if (t8 != l8) { + break; + } + t <<= 8; + L <<= 8; + } + } + + private: + uint R; + uint L; + bool start; + byte buffer; + uint carryN; + Iter iter; + uint counter; }; template struct rc_decoder_search_traits_t : public rc_type_t { - typedef FreqType freq_type; - enum { - N = _N, - BASE = _BASE - }; + typedef FreqType freq_type; + enum { + N = _N, + BASE = _BASE + }; }; template struct rc_decoder_search_t : public rc_decoder_search_traits_t { - static rc_type_t::uint get_index(const FreqType *freq, FreqType pos) { - rc_type_t::uint left = 0; - rc_type_t::uint right = _N - 1; - while(left < right) { - rc_type_t::uint mid = (left+right)/2; - if (freq[mid+1] <= pos) left = mid+1; - else right = mid; - } - return left; - } + static rc_type_t::uint get_index(const FreqType *freq, FreqType pos) { + rc_type_t::uint left = 0; + rc_type_t::uint right = _N - 1; + while (left < right) { + rc_type_t::uint mid = (left + right) / 2; + if (freq[mid+1] <= pos) { + left = mid + 1; + } else { + right = mid; + } + } + return left; + } }; #ifdef RANGE_CODER_USE_SSE - template struct rc_decoder_search_t : public rc_decoder_search_traits_t { - static rc_type_t::uint get_index(const short *freq, short pos) { - __m128i v = _mm_set1_epi16(pos); - unsigned i, mask = 0; - for (i = 0; i < 256; i += 16) { - __m128i x = *reinterpret_cast(freq + i); - __m128i y = *reinterpret_cast(freq + i + 8); - __m128i a = _mm_cmplt_epi16(v, x); - __m128i b = _mm_cmplt_epi16(v, y); - mask = (_mm_movemask_epi8(b) << 16) | _mm_movemask_epi8(a); - if (mask) { - return i + (__builtin_ctz(mask) >> 1) - 1; - } - } - return 255; - } + static rc_type_t::uint get_index(const short *freq, short pos) { + __m128i v = _mm_set1_epi16(pos); + unsigned i, mask = 0; + for (i = 0; i < 256; i += 16) { + __m128i x = *reinterpret_cast(freq + i); + __m128i y = *reinterpret_cast(freq + i + 8); + __m128i a = _mm_cmplt_epi16(v, x); + __m128i b = _mm_cmplt_epi16(v, y); + mask = (_mm_movemask_epi8(b) << 16) | _mm_movemask_epi8(a); + if (mask) { + return i + (__builtin_ctz(mask) >> 1) - 1; + } + } + return 255; + } }; - #endif template class rc_decoder_t : public rc_type_t { -public: - typedef SearchType search_type; - typedef typename search_type::freq_type freq_type; - static const unsigned N = search_type::N; - rc_decoder_t(const Iterator& _i, const Iterator _e) : iter(_i), iter_end(_e) { - R = 0xFFFFFFFF; - D = 0; - for (int i = 0; i < 4; i++) { - D = (D << 8) | next(); - } - } - uint decode(const uint total, const freq_type* cumFreq) { - const uint r = R / total; - const int targetPos = std::min(total-1, D / r); - - //find target s.t. cumFreq[target] <= targetPos < cumFreq[target+1] - const uint target = - search_type::get_index(cumFreq, targetPos + search_type::BASE); - const uint low = cumFreq[target] - search_type::BASE; - const uint high = cumFreq[target+1] - search_type::BASE; - - D -= r * low; - if (high != total) { - R = r * (high-low); - } else { - R -= r * low; - } - - while (R < TOP) { - R <<= 8; - D = (D << 8) | next(); - } - - return target; - } - byte next() { - return iter != iter_end ? (byte)*iter++ : 0xff; - } -private: - uint R; - uint D; - Iterator iter, iter_end; + public: + typedef SearchType search_type; + typedef typename search_type::freq_type freq_type; + static const unsigned N = search_type::N; + + rc_decoder_t(const Iterator& _i, const Iterator _e) : iter(_i), iter_end(_e) { + R = 0xFFFFFFFF; + D = 0; + + // read first four bytes from file + for (int i = 0; i < 4; i++) { + D = (D << 8) | next(); + } + } + + uint decode(const uint total, const freq_type* cumFreq) { + const uint r = R / total; + const int targetPos = std::min(total - 1, D / r); + + // find target s.t. cumFreq[target] <= targetPos < cumFreq[target + 1] + const uint target = search_type::get_index(cumFreq, targetPos + search_type::BASE); + const uint low = cumFreq[target] - search_type::BASE; + const uint high = cumFreq[target + 1] - search_type::BASE; + + D -= r * low; + if (high != total) { + R = r * (high - low); + } else { + R -= r * low; + } + + while (R < TOP) { + R <<= 8; + D = (D << 8) | next(); + } + + return target; + } + + byte next() { + return iter != iter_end ? (byte)*iter++ : 0xff; + } + + private: + uint R; + uint D; + Iterator iter, iter_end; }; #endif diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..c44f2f7 --- /dev/null +++ b/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup, Extension + +setup( + name="range_coder", + version="1.1.2", + description="A fast implementation of a range coder", + packages=["range_coder"], + package_dir={"range_coder": "python"}, + license="BSD", + ext_modules=[ + Extension( + "range_coder._range_coder", + language="c++", + include_dirs=["."], + sources=[ + "python/src/range_coder_interface.cpp", + "python/src/module.cpp", + ] + ) + ] +)