Fix wrong percentile values returned during calibration (microsoft#10847)

* Use numpy.percentile to get the lookup value.

* Use 1.0 as a float value rather than an integer.

* Add missing cdf parameter for `np.percentile`.

* Use 100. instead of 1.0 (`np.percentile` takes percentages on a 0-100 scale; see the sketch before the diff below).

* Remove print.

* Update from @yufenglee
mfuntowicz authored and seddonm1 committed May 15, 2022
1 parent 4a2223d commit 4662bbf
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions onnxruntime/python/tools/quantization/calibrate.py
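Before the diff, a minimal sketch of the pattern the fix converges on (the function name and defaults are illustrative assumptions, not the actual onnxruntime code). The key detail is that `np.percentile` takes its `q` argument on a 0-100 scale, so a 99.99th-percentile lookup must be requested as 99.99, not 0.9999:

import numpy as np

def symmetric_percentile_range(data_arr, percentile=99.99):
    # Illustrative only: clip a tensor's calibration range to the given
    # percentile of absolute values. np.percentile expects q in [0, 100];
    # passing a 0-1 fraction instead silently returns a near-minimum value,
    # which is the class of bug this commit describes.
    flat = np.abs(np.asarray(data_arr, dtype=np.float64).ravel())
    threshold = np.percentile(flat, percentile)
    return (-float(threshold), float(threshold))

# Example: standard-normal samples clipped at the 99.99th percentile.
rng = np.random.default_rng(0)
print(symmetric_percentile_range(rng.standard_normal(100_000)))

Delegating the lookup to np.percentile also avoids the off-by-one risk of indexing a hand-built cumulative histogram.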
@@ -86,7 +86,7 @@ def _create_inference_session(self):

def select_tensors_to_calibrate(self, model):
'''
select all quantization_candidates op type nodes' input/output tensors.
returns:
tensors (set): set of tensor name.
value_infos (dict): tensor name to value info.
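A hedged sketch of what such a selection could look like with the onnx API (the op types, helper name, and usage are assumptions; the method's full body is outside this hunk):

import onnx

def select_candidate_tensors(model, op_types_to_calibrate=('Conv', 'MatMul')):
    # Sketch only: gather input/output tensor names of candidate nodes,
    # keeping those with a ValueInfo entry (and hence a known type/shape).
    value_infos = {vi.name: vi for vi in model.graph.value_info}
    value_infos.update({it.name: it for it in model.graph.input})
    value_infos.update({ot.name: ot for ot in model.graph.output})
    tensors = set()
    for node in model.graph.node:
        if node.op_type in op_types_to_calibrate:
            candidates = list(node.input) + list(node.output)
            tensors.update(t for t in candidates if t in value_infos)
    return tensors, value_infos

model = onnx.load('augmented_model.onnx')  # the module's default path, assumed present
print(select_candidate_tensors(model))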
@@ -139,7 +139,7 @@ def compute_range(self, data_reader: CalibrationDataReader):


class MinMaxCalibrater(CalibraterBase):
def __init__(self,
model,
op_types_to_calibrate=[],
augmented_model_path='augmented_model.onnx',
@@ -178,7 +178,7 @@ def augment_graph(self):

added_nodes = []
added_outputs = []
tensors, value_infos = self.select_tensors_to_calibrate(model)

for tensor in tensors:

@@ -233,7 +233,7 @@ def merge_range(self, old_range, new_range):
if not old_range:
return new_range

for key, value in old_range.items():
if self.moving_average:
min_value = value[0] + self.averaging_constant * (new_range[key][0] - value[0])
max_value = value[1] + self.averaging_constant * (new_range[key][1] - value[1])
@@ -245,7 +245,7 @@ def merge_range(self, old_range, new_range):
return new_range

def compute_range(self):
'''
Compute the min-max range of tensor
:return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
'''
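As an aside on the moving-average branch in merge_range above: it is an exponential moving average of each tensor's running (min, max) pair. A self-contained sketch of that update rule (dict layout taken from the diff; the constant's default value is an assumption):

def ema_merge_range(old_range, new_range, averaging_constant=0.01):
    # Each tensor's running (min, max) moves a fixed fraction of the way
    # toward the newly observed pair; a small constant favors stability
    # over reactivity.
    if not old_range:
        return new_range
    merged = {}
    for key, (old_min, old_max) in old_range.items():
        new_min, new_max = new_range[key]
        merged[key] = (old_min + averaging_constant * (new_min - old_min),
                       old_max + averaging_constant * (new_max - old_max))
    return merged

print(ema_merge_range({'conv1': (-1.0, 1.0)}, {'conv1': (-3.0, 2.0)}))
# -> {'conv1': (-1.02, 1.01)}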
@@ -295,7 +295,7 @@ def compute_range(self):
if self.calibrate_tensors_range:
self.calibrate_tensors_range = self.merge_range(self.calibrate_tensors_range, new_calibrate_tensors_range)
else:
self.calibrate_tensors_range = new_calibrate_tensors_range

return self.calibrate_tensors_range

@@ -344,7 +344,7 @@ def augment_graph(self):

added_nodes = []
added_outputs = []
tensors, value_infos = self.select_tensors_to_calibrate(model)

for tensor in tensors:
added_outputs.append(value_infos[tensor])
@@ -359,7 +359,7 @@ def clear_collected_data(self):

def collect_data(self, data_reader: CalibrationDataReader):
'''
Entropy Calibrator collects operators' tensors as well as generates tensor histogram for each operator.
'''
while True:
inputs = data_reader.get_next()
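For context on the loop entered above: a CalibrationDataReader serves feed dictionaries one batch at a time and signals exhaustion by returning None, which is what ends the `while True` loop. A hedged sketch of such a reader (input name, shape, and batch count are assumptions):

import numpy as np
from onnxruntime.quantization import CalibrationDataReader

class RandomDataReader(CalibrationDataReader):
    # Illustrative reader: serves a fixed number of random NCHW batches.
    def __init__(self, input_name='input', num_batches=8, shape=(1, 3, 224, 224)):
        rng = np.random.default_rng(0)
        self._batches = iter(
            [{input_name: rng.standard_normal(shape).astype(np.float32)}
             for _ in range(num_batches)])

    def get_next(self):
        # Next feed dict, or None once exhausted.
        return next(self._batches, None)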
@@ -393,7 +393,7 @@ def collect_data(self, data_reader: CalibrationDataReader):
self.clear_collected_data()

def compute_range(self):
'''
Compute the min-max range of tensor
:return: dictionary mapping: {tensor name: (min value, max value)}
'''
@@ -457,15 +457,15 @@ class CalibrationDataCollector(metaclass=abc.ABCMeta):
def collect(self, name_to_arr):
"""
Generate informative data based on given data.
name_to_arr : dict
tensor name to NDArray data
"""
raise NotImplementedError

@abc.abstractmethod
def compute_collection_result(self):
"""
Get the optimal result among collection data.
"""
raise NotImplementedError

@@ -569,7 +569,7 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_threshold):
else:
old_num_bins = len(old_hist)
old_stride = 2 * old_threshold / old_num_bins
half_increased_bins = int((new_threshold - old_threshold) // old_stride + 1)
new_num_bins = old_num_bins + 2 * half_increased_bins
new_threshold = half_increased_bins * old_stride + old_threshold
hist, hist_edges = np.histogram(data_arr, new_num_bins, range=(-new_threshold, new_threshold))
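To make the bin arithmetic above concrete, a small worked example under assumed numbers: an old histogram of 16 bins over [-2, 2] has stride 2 * 2 / 16 = 0.25, so growing to an observed threshold of 2.5 adds int(0.5 // 0.25 + 1) = 3 bins per side, giving 22 bins over a widened range of +/-2.75, snapped to the old stride so the old counts can later be added bin-for-bin into the middle:

import numpy as np

old_threshold, old_num_bins = 2.0, 16  # assumed for illustration
observed_threshold = 2.5               # new data exceeds the old range

old_stride = 2 * old_threshold / old_num_bins                                      # 0.25
half_increased_bins = int((observed_threshold - old_threshold) // old_stride + 1)  # 3
new_num_bins = old_num_bins + 2 * half_increased_bins                              # 22
new_threshold = half_increased_bins * old_stride + old_threshold                   # 2.75

data_arr = np.random.default_rng(0).standard_normal(1000)
hist, hist_edges = np.histogram(data_arr, new_num_bins,
                                range=(-new_threshold, new_threshold))
print(new_num_bins, new_threshold, hist.size)  # 22 2.75 22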
@@ -655,9 +655,9 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
num_bins = hist.size
zero_bin_index = num_bins // 2
num_half_quantized_bin = num_quantized_bins // 2

kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1)
thresholds = [(0, 0) for i in range(kl_divergence.size)]

# <------------ num bins ---------------->
# <--- quantized bins ---->
@@ -674,7 +674,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
# start index end index (end of iteration)

for i in range(num_half_quantized_bin, zero_bin_index + 1, 1):
start_index = zero_bin_index - i
end_index = zero_bin_index + i + 1 if (zero_bin_index + i + 1) <= num_bins else num_bins

thresholds[i - num_half_quantized_bin] = (float(hist_edges[start_index]), float(hist_edges[end_index]))
@@ -683,23 +683,23 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):

# reference distribution p
p = sliced_distribution.copy() # a copy of np array
left_outliers_count = sum(hist[:start_index])
right_outliers_count = sum(hist[end_index:])
p[0] += left_outliers_count
p[-1] += right_outliers_count

# nonzeros[i] indicates whether p[i] is non-zero
nonzeros = (p != 0).astype(np.int64)

# quantize p.size bins into quantized bins (default 128 bins)
quantized_bins = np.zeros(num_quantized_bins, dtype=np.int64)
num_merged_bins = sliced_distribution.size // num_quantized_bins

# merge bins into quantized bins
for index in range(num_quantized_bins):
start = index * num_merged_bins
end = start + num_merged_bins
quantized_bins[index] = sum(sliced_distribution[start:end])
quantized_bins[-1] += sum(sliced_distribution[num_quantized_bins * num_merged_bins:])

# in order to compare p and q, we need to make length of q equals to length of p
@@ -712,7 +712,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
norm = sum(nonzeros[start:end])
if norm != 0:
q[start:end] = float(quantized_bins[index]) / float(norm)

p = smooth_distribution(p)
q = smooth_distribution(q)

@@ -722,7 +722,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
kl_divergence[i - num_half_quantized_bin] = float('inf')

min_kl_divergence_idx = np.argmin(kl_divergence)
optimal_threshold = thresholds[min_kl_divergence_idx]

return optimal_threshold
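Putting the pieces of the loop above together, a compact sketch of the whole KL-divergence threshold search. This is not the exact onnxruntime implementation: smooth() below is a stand-in assumption (smooth_distribution's body is outside this diff), and the KL term uses scipy.stats.entropy:

import numpy as np
from scipy.stats import entropy

def smooth(p, eps=1e-4):
    # Stand-in for smooth_distribution: shift a little probability mass
    # onto zero bins so the KL divergence stays finite.
    p = p.astype(np.float64)
    is_zero = p == 0
    n_zeros = int(is_zero.sum())
    if n_zeros == 0 or n_zeros == p.size:
        return p
    p[is_zero] = eps
    p[~is_zero] -= eps * n_zeros / (p.size - n_zeros)
    return p

def entropy_threshold(hist, hist_edges, num_quantized_bins=128):
    # Widen a symmetric window around the center bin, fold outliers into
    # the window's edge bins, quantize the window to num_quantized_bins,
    # and keep the threshold pair whose quantized distribution stays
    # closest (in KL divergence) to the original.
    num_bins = hist.size
    zero_bin = num_bins // 2
    half_q = num_quantized_bins // 2
    best_kl, best_threshold = float('inf'), None
    for i in range(half_q, zero_bin + 1):
        start = zero_bin - i
        end = min(zero_bin + i + 1, num_bins)
        p = hist[start:end].astype(np.float64)
        p[0] += hist[:start].sum()   # fold in left outliers
        p[-1] += hist[end:].sum()    # fold in right outliers
        nonzeros = (p != 0).astype(np.float64)
        merged = p.size // num_quantized_bins
        q = np.zeros_like(p)
        for j in range(num_quantized_bins):
            s = j * merged
            e = p.size if j == num_quantized_bins - 1 else s + merged
            norm = nonzeros[s:e].sum()
            if norm != 0:
                q[s:e] = (p[s:e].sum() / norm) * nonzeros[s:e]
        kl = entropy(smooth(p), smooth(q))
        if kl < best_kl:
            best_kl = kl
            best_threshold = (float(hist_edges[start]), float(hist_edges[end]))
    return best_threshold

# Example with assumed data: 2048 bins over a symmetric range.
data = np.random.default_rng(0).standard_normal(200_000)
hist, edges = np.histogram(data, bins=2048, range=(-8.0, 8.0))
print(entropy_threshold(hist, edges))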
