Skip to content

Commit

Permalink
Fix h-label loss normalization issue w/ exclusive label group of sing…
Browse files Browse the repository at this point in the history
…e label (#2604)

* Fix h-label loss normalization issue w/ exclusive label group with signle label

* Fix non-linear version

---------
Signed-off-by: Songki Choi <songki.choi@intel.com>
  • Loading branch information
goodsong81 authored Nov 6, 2023
1 parent c22c683 commit a4abbed
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ All notable changes to this project will be documented in this file.

- Fix IBLoss enablement with DeiT-Tiny when class incremental training (<https://github.com/openvinotoolkit/training_extensions/pull/2595>)
- Fix mmcls bug not wrapping model in DataParallel on CPUs (<https://github.com/openvinotoolkit/training_extensions/pull/2601>)
- Fix h-label loss normalization issue w/ exclusive label group of singe label (<https://github.com/openvinotoolkit/training_extensions/pull/2604>)

## \[v1.4.3\]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def forward_train(self, cls_score, gt_label, **kwargs):
losses["loss"] += multiclass_loss
num_effective_heads_in_batch += 1

if self.hierarchical_info["num_multiclass_heads"] > 1:
if num_effective_heads_in_batch > 0:
losses["loss"] /= num_effective_heads_in_batch

if self.compute_multilabel_loss:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def forward_train(self, cls_score, gt_label, **kwargs):
losses["loss"] += multiclass_loss
num_effective_heads_in_batch += 1

if self.hierarchical_info["num_multiclass_heads"] > 1:
if num_effective_heads_in_batch > 0:
losses["loss"] /= num_effective_heads_in_batch

if self.compute_multilabel_loss:
Expand Down

0 comments on commit a4abbed

Please # to comment.