# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# Draws diagrams to provide an intuitive understanding of the discretization
# that results from 16-bit and 8-bit weight quantization.

from matplotlib import pyplot as plt
import numpy as np


def quantize(w, bits):
  """Simulate weight quantization.

  The range [w_min, w_max] is divided into 2**bits equal-width bins and each
  value is replaced by the index of the bin it falls into.

  Args:
    w: (a numpy.ndarray) The weight to be quantized.
    bits: (int) number of bits used for the quantization: 8 or 16.

  Returns:
    A tuple with three elements:
      w_quantized: the quantized version of w, represented as an uint8-
        or uint16-type numpy.ndarray.
      w_min: Minimum value of w, required for dequantization.
      w_max: Maximum value of w, required for dequantization.

  Raises:
    ValueError: if `bits` is neither 8 nor 16, or if all elements of `w`
      are equal (the range is 0, so no scale can be computed).
  """
  if bits == 8:
    dtype = np.uint8
  elif bits == 16:
    dtype = np.uint16
  else:
    raise ValueError('Unsupported bits of quantization: %s' % bits)

  w_min = np.min(w)
  w_max = np.max(w)
  if w_max == w_min:
    raise ValueError('Cannot perform quantization because w has a range of 0')

  n_levels = np.power(2, bits)
  bin_index = np.floor((w - w_min) / (w_max - w_min) * n_levels)
  # w == w_max lands exactly on n_levels, which is one past the largest code
  # representable in `dtype`. Clamp it into the top bin so the cast below
  # cannot overflow.
  bin_index = np.minimum(bin_index, n_levels - 1)
  w_quantized = np.array(bin_index, dtype)
  return w_quantized, w_min, w_max


def dequantize(w_quantized, w_min, w_max):
  """Simulate weight de-quantization.

  Each integer code is mapped back to the lower edge of its bin, i.e.
  `w_min + code / 2**bits * (w_max - w_min)`. The largest code therefore
  maps to a value slightly below `w_max`.

  Args:
    w_quantized: (a numpy.ndarray) The quantized weight, of dtype uint8 or
      uint16. The number of bits is derived from this dtype.
    w_min: Minimum value of the original weight, as returned by `quantize`.
    w_max: Maximum value of the original weight, as returned by `quantize`.

  Returns:
    The de-quantized weight, computed in float64.

  Raises:
    ValueError: if the dtype of `w_quantized` is neither uint8 nor uint16.
  """
  if w_quantized.dtype == np.uint8:
    bits = 8
  elif w_quantized.dtype == np.uint16:
    bits = 16
  else:
    raise ValueError(
        'Unsupported dtype in quantized values: %s' % w_quantized.dtype)
  return (w_min + w_quantized.astype(np.float64) / np.power(2, bits) *
          (w_max - w_min))


def main():
  """Plot a linear ramp next to its 16-bit and 8-bit quantized versions."""
  # Number of sample points along the x-axis. This must be an int because
  # np.linspace rejects a float sample count.
  n_points = int(1e6)
  xs = np.linspace(-np.pi, np.pi, n_points).astype(np.float64)
  # The "weight" being quantized is simply the identity function of x.
  w = xs

  w_16bit = dequantize(*quantize(w, 16))
  w_8bit = dequantize(*quantize(w, 8))

  # Zoom in on a narrow window around the midpoint of the x-axis so that the
  # individual quantization steps are visible.
  plot_delta = 1.2e-4
  window = slice(int(n_points * (0.5 - plot_delta)),
                 int(n_points * (0.5 + plot_delta)))

  # NOTE(review): the arrays plotted here are float64 although the first
  # title says float32 -- confirm which label is intended.
  panels = (
      ('Original (float32)', w),
      ('16-bit quantization', w_16bit),
      ('8-bit quantization', w_8bit),
  )

  plt.figure(figsize=(20, 6))
  for position, (title, values) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.plot(xs[window], values[window], '-')
    plt.title(title, {'fontsize': 16})
    plt.xlabel('x')
  plt.show()


if __name__ == '__main__':
  main()