Skip to content

Commit

Permalink
[Improvement] speed up confusion matrix calculation (#465)
Browse files Browse the repository at this point in the history
* a little bit faster confusion matrix

* add changelog
  • Loading branch information
dreamerlin authored Dec 21, 2020
1 parent 30ff6b2 commit 777546f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Support training and testing for Spatio-Temporal Action Detection ([#351](https://github.com/open-mmlab/mmaction2/pull/351))
- Fix CI due to pip upgrade ([#454](https://github.com/open-mmlab/mmaction2/pull/454))
- Add markdown lint in pre-commit hook ([#255](https://github.com/open-mmlab/mmaction2/pull/225))
- Speed up confusion matrix calculation ([#465](https://github.com/open-mmlab/mmaction2/pull/465))
- Use title case in modelzoo statistics ([#456](https://github.com/open-mmlab/mmaction2/pull/456))
- Add FAQ documents for easy troubleshooting. ([#413](https://github.com/open-mmlab/mmaction2/pull/413), [#420](https://github.com/open-mmlab/mmaction2/pull/420), [#439](https://github.com/open-mmlab/mmaction2/pull/439))

Expand Down
17 changes: 11 additions & 6 deletions mmaction/core/evaluation/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,17 @@ def confusion_matrix(y_pred, y_real, normalize=None):

label_set = np.unique(np.concatenate((y_pred, y_real)))
num_labels = len(label_set)
label_map = {label: i for i, label in enumerate(label_set)}
confusion_mat = np.zeros((num_labels, num_labels), dtype=np.int64)
for rlabel, plabel in zip(y_real, y_pred):
index_real = label_map[rlabel]
index_pred = label_map[plabel]
confusion_mat[index_real][index_pred] += 1
max_label = label_set[-1]
label_map = np.zeros(max_label + 1, dtype=np.int64)
for i, label in enumerate(label_set):
label_map[label] = i

y_pred_mapped = label_map[y_pred]
y_real_mapped = label_map[y_real]

confusion_mat = np.bincount(
num_labels * y_real_mapped + y_pred_mapped,
minlength=num_labels**2).reshape(num_labels, num_labels)

with np.errstate(all='ignore'):
if normalize == 'true':
Expand Down

0 comments on commit 777546f

Please # to comment.