Fix error in old PyTorch for KL and JS divergence (#3236)
* add KLDivergence metric

* add JSDivergence

* fix variable name

* update docstring for JSDivergence

* Update ignite/metrics/js_divergence.py

Co-authored-by: vfdev <vfdev.5@gmail.com>

* Update ignite/metrics/kl_divergence.py

Co-authored-by: vfdev <vfdev.5@gmail.com>

* swap ground truth and prediction

* swap the definitions of p and q

* fix error in old pytorch

* switch to use log_target option by version

* check pytorch version in the global space in advance

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
kzkadc and vfdev-5 authored Apr 26, 2024
1 parent 95c0154 commit 565e8be
Showing 2 changed files with 33 additions and 9 deletions.
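
For context on the fix itself: torch.nn.functional.kl_div only accepts the log_target keyword from PyTorch 1.6.0 onward (per the code comments below), so the previous unconditional log_target=True call failed on older releases. The fallback path instead passes plain probabilities as the target, which yields the same KL value. A minimal sketch of that equivalence (tensor names are illustrative, not taken from the diff):

import torch
import torch.nn.functional as F

# p: "true" distribution, q: approximating distribution; each row sums to 1
p = torch.softmax(torch.randn(4, 10), dim=1)
q = torch.softmax(torch.randn(4, 10), dim=1)

# Old-style call: target given as probabilities (works on any PyTorch version)
kl_prob_target = F.kl_div(q.log(), p, reduction="sum")

# New-style call: target given as log-probabilities (requires torch >= 1.6.0)
kl_log_target = F.kl_div(q.log(), p.log(), log_target=True, reduction="sum")

assert torch.allclose(kl_prob_target, kl_log_target)  # both equal KL(p || q)
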
27 changes: 20 additions & 7 deletions ignite/metrics/js_divergence.py
@@ -1,12 +1,15 @@
 import torch
 import torch.nn.functional as F
+from packaging.version import Version
 
 from ignite.exceptions import NotComputableError
 from ignite.metrics.kl_divergence import KLDivergence
 from ignite.metrics.metric import sync_all_reduce
 
 __all__ = ["JSDivergence"]
 
+TORCH_VERSION_GE_160 = Version(torch.__version__) >= Version("1.6.0")
+
 
 class JSDivergence(KLDivergence):
     r"""Calculates the mean of `Jensen-Shannon (JS) divergence
@@ -71,14 +74,24 @@ class JSDivergence(KLDivergence):
"""

def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
m_prob = (F.softmax(y_pred, dim=1) + F.softmax(y, dim=1)) / 2
y_pred_prob = F.softmax(y_pred, dim=1)
y_prob = F.softmax(y, dim=1)
m_prob = (y_pred_prob + y_prob) / 2
m_log = m_prob.log()
y_pred = F.log_softmax(y_pred, dim=1)
y = F.log_softmax(y, dim=1)
self._sum_of_kl += (
F.kl_div(m_log, y_pred, log_target=True, reduction="sum")
+ F.kl_div(m_log, y, log_target=True, reduction="sum")
).to(self._device)

if TORCH_VERSION_GE_160:
# log_target option can be used from 1.6.0
y_pred_log = F.log_softmax(y_pred, dim=1)
y_log = F.log_softmax(y, dim=1)
self._sum_of_kl += (
F.kl_div(m_log, y_pred_log, log_target=True, reduction="sum")
+ F.kl_div(m_log, y_log, log_target=True, reduction="sum")
).to(self._device)
else:
# y_pred and y are expected to be probabilities
self._sum_of_kl += (
F.kl_div(m_log, y_pred_prob, reduction="sum") + F.kl_div(m_log, y_prob, reduction="sum")
).to(self._device)

@sync_all_reduce("_sum_of_kl", "_num_examples")
def compute(self) -> float:
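
For reference, the patched _update follows the usual decomposition of JS divergence through the mixture M = (P + Q) / 2, i.e. JS(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M); it accumulates the two unscaled KL terms, with the factor of 0.5 presumably applied in compute (not shown in this hunk). A rough standalone sketch of the same computation, assuming logits of shape (batch, num_classes); the function name is made up:

import torch
import torch.nn.functional as F

def js_divergence_sum(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Sum of JS divergences over a batch of logits (illustrative only)."""
    q = F.softmax(y_pred, dim=1)   # predicted distribution Q
    p = F.softmax(y, dim=1)        # ground-truth distribution P
    m_log = ((p + q) / 2).log()    # log of the mixture M
    # probability targets keep this runnable on PyTorch < 1.6.0 as well
    return 0.5 * (F.kl_div(m_log, p, reduction="sum") + F.kl_div(m_log, q, reduction="sum"))
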
15 changes: 13 additions & 2 deletions ignite/metrics/kl_divergence.py
@@ -2,12 +2,15 @@
 
 import torch
 import torch.nn.functional as F
+from packaging.version import Version
 
 from ignite.exceptions import NotComputableError
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
 
 __all__ = ["KLDivergence"]
 
+TORCH_VERSION_GE_160 = Version(torch.__version__) >= Version("1.6.0")
+
 
 class KLDivergence(Metric):
     r"""Calculates the mean of `Kullback-Leibler (KL) divergence
@@ -91,8 +94,16 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
 
     def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
         y_pred = F.log_softmax(y_pred, dim=1)
-        y = F.log_softmax(y, dim=1)
-        kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")
+
+        if TORCH_VERSION_GE_160:
+            # log_target option can be used from 1.6.0
+            y = F.log_softmax(y, dim=1)
+            kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")
+        else:
+            # y is expected to be a probability tensor
+            y = F.softmax(y, dim=1)
+            kl_sum = F.kl_div(y_pred, y, reduction="sum")
+
         self._sum_of_kl += kl_sum.to(self._device)
 
     @sync_all_reduce("_sum_of_kl", "_num_examples")
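
Finally, a rough usage sketch of the two metrics these files implement, assuming both classes are re-exported from ignite.metrics (otherwise import them from ignite.metrics.kl_divergence and ignite.metrics.js_divergence); shapes and values are illustrative:

import torch
from ignite.metrics import JSDivergence, KLDivergence

kl_metric = KLDivergence()
js_metric = JSDivergence()

y_pred = torch.randn(8, 5)  # predicted logits, shape (batch, num_classes)
y = torch.randn(8, 5)       # ground-truth logits

# update takes a (y_pred, y) pair, matching the signature shown in the diff above
kl_metric.update((y_pred, y))
js_metric.update((y_pred, y))

print(kl_metric.compute(), js_metric.compute())
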
