def quantize_array_target(A, target_err):
    """Binary-search the quantization parameter that meets a target error.

    Searches integer values ``q`` in ``[1, 128]`` and, for each candidate,
    calls ``quantize_array(A, q)`` and measures the mean relative error:
    the Euclidean norm of ``dequant_A - A`` along the last axis divided by
    the norm of ``A`` along the last axis, averaged over all entries.

    The search returns the smallest ``q`` whose error is ``<= target_err``,
    presuming the error does not increase as ``q`` grows (NOTE(review): this
    monotonicity is assumed by the bisection, not checked). If no candidate
    below 128 meets the target, ``q = 128`` is returned without comparing its
    error against ``target_err``.

    Args:
        A: NumPy array to quantize; norms are taken over ``axis=-1``.
            NOTE(review): a row with zero norm makes the relative error
            divide by zero -- confirm callers never pass such rows.
        target_err: Maximum acceptable mean relative error.

    Returns:
        Tuple ``(q, mean_err, quant_A, dequant_A)`` for the selected ``q``,
        where ``quant_A`` and ``dequant_A`` are whatever ``quantize_array``
        returns for that ``q``.
    """
    low = 1
    high = 128
    A_norms = np.sqrt(np.sum(A**2, axis=-1))

    def _evaluate(q):
        # Quantize with parameter q and measure the mean relative error.
        quant, dequant = quantize_array(A, q)
        err = np.mean(np.sqrt(np.sum((dequant - A)**2, axis=-1)) / A_norms)
        return err, quant, dequant

    while low < high:
        # Floor division keeps the candidate an integer on Python 3 as well;
        # with true division the bounds turn into floats and the
        # ``mid + 1`` / ``mid`` updates no longer bisect an integer range.
        mid = low + (high - low) // 2
        mean_err, _, _ = _evaluate(mid)
        logging.info("Binary search: q=%d, err=%.3f", mid, mean_err)
        if mean_err > target_err:
            # mid is too coarse: the answer lies strictly above it.
            low = mid + 1
        else:
            # mid satisfies the target: keep it as the upper bound.
            high = mid

    # low == high here; re-run at the selected value so the returned arrays
    # and error always correspond to the returned q.
    mid = low
    mean_err, quant_A, dequant_A = _evaluate(mid)
    logging.info("Result: q=%d, err=%.3f", mid, mean_err)
    return mid, mean_err, quant_A, dequant_A