|
| 1 | +import argparse |
| 2 | +import os |
| 3 | + |
| 4 | +import mmengine |
| 5 | + |
| 6 | + |
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: parsed arguments with a single ``file_path``
        attribute pointing at the benchmark ``summary.yml`` file.
    """
    parser = argparse.ArgumentParser(
        description='Analyse summary.yml generated by benchmark test')
    parser.add_argument('file_path', help='Summary.yml path')
    return parser.parse_args()
| 13 | + |
| 14 | + |
# Maps expected-metric names (as written in the ``expect`` section of
# summary.yml) to the corresponding keys reported in the ``actual`` results.
metric_mapping = {
    'Top 1 Accuracy': 'accuracy/top1',
    'Top 5 Accuracy': 'accuracy/top5',
    'box AP': 'coco/bbox_mAP',
    'mIoU': 'mIoU'
}
| 21 | + |
| 22 | + |
def compare_metric(result, metric):
    """Compare one expected metric value against the actual benchmark result.

    Args:
        result (dict): entry with an ``expect`` dict (metric name ->
            expected value) and an ``actual`` dict (result key ->
            measured value).
        metric (str): metric name; must be a key of ``metric_mapping``.

    Returns:
        tuple: ``(error, error_percent)``, both rounded to the precision of
        the expected value, or ``(None, None)`` when the actual results do
        not contain the metric.
    """
    expect_val = result['expect'][metric]
    actual_val = result['actual'].get(metric_mapping[metric], None)
    if actual_val is None:
        return None, None
    # COCO box AP is reported in [0, 1] while expected values use
    # percentages, so rescale before comparing.
    if metric == 'box AP':
        actual_val *= 100
    # Round the actual value to the same number of decimal places as the
    # expected value. The previous ``split('.')[-1]`` returned the whole
    # number string for integer expectations (e.g. 76 -> 2 decimal places);
    # integers must count as zero decimal places.
    expect_str = str(expect_val)
    decimal_bit = len(expect_str.split('.')[1]) if '.' in expect_str else 0
    actual_val = round(actual_val, decimal_bit)
    error = round(actual_val - expect_val, decimal_bit)
    # Guard against a zero expected value, which previously raised
    # ZeroDivisionError; report an infinite relative error instead.
    if expect_val:
        error_percent = round(abs(error) * 100 / expect_val, 3)
    else:
        error_percent = float('inf')
    return error, error_percent
| 35 | + |
| 36 | + |
def main():
    """Analyse a benchmark ``summary.yml`` and report metric regressions.

    Loads the summary file given on the command line, compares every
    expected metric against the actual result, collects the models with a
    nonzero error, sorts them by their largest absolute error, and dumps
    the report to ``summary_error.yml`` next to the input file.
    """
    args = parse_args()
    file_path = args.file_path
    results = mmengine.load(file_path, 'yml')
    miss_models = dict()
    sort_by_error = dict()
    for model, result in results.items():
        compare_res = dict()
        # Track the worst (largest absolute) error seen across this
        # model's metrics. Previously the sort key was whichever nonzero
        # error happened to be computed last, making the ordering
        # arbitrary for multi-metric models.
        worst_error = 0
        for metric in result['expect'].keys():
            error, error_percent = compare_metric(result, metric)
            if error is None:
                # Metric missing from the actual results; skip it.
                continue
            compare_res[metric] = {
                'error': error,
                'error_percent': error_percent,
            }
            if error != 0:
                miss_models[model] = compare_res
                if abs(error) > abs(worst_error):
                    worst_error = error
                sort_by_error[model] = worst_error
    ranked = sorted(
        sort_by_error.items(), key=lambda kv: abs(kv[1]), reverse=True)
    miss_models_sort = dict()
    miss_models_sort['total error models'] = len(ranked)
    for model, _ in ranked:
        miss_models_sort[model] = miss_models[model]
    save_path = os.path.join(os.path.dirname(file_path), 'summary_error.yml')
    mmengine.fileio.dump(miss_models_sort, save_path, sort_keys=False)
    print(f'Summary analysis result saved in {save_path}')
| 64 | + |
| 65 | + |
# Entry point when executed as a script.
if __name__ == '__main__':
    main()
0 commit comments