diff --git a/ignite/metrics/js_divergence.py b/ignite/metrics/js_divergence.py index 1bd37cfedc6..ee223014061 100644 --- a/ignite/metrics/js_divergence.py +++ b/ignite/metrics/js_divergence.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from packaging.version import Version from ignite.exceptions import NotComputableError from ignite.metrics.kl_divergence import KLDivergence @@ -7,6 +8,8 @@ __all__ = ["JSDivergence"] +TORCH_VERSION_GE_160 = Version(torch.__version__) >= Version("1.6.0") + class JSDivergence(KLDivergence): r"""Calculates the mean of `Jensen-Shannon (JS) divergence @@ -71,14 +74,24 @@ class JSDivergence(KLDivergence): """ def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: - m_prob = (F.softmax(y_pred, dim=1) + F.softmax(y, dim=1)) / 2 + y_pred_prob = F.softmax(y_pred, dim=1) + y_prob = F.softmax(y, dim=1) + m_prob = (y_pred_prob + y_prob) / 2 m_log = m_prob.log() - y_pred = F.log_softmax(y_pred, dim=1) - y = F.log_softmax(y, dim=1) - self._sum_of_kl += ( - F.kl_div(m_log, y_pred, log_target=True, reduction="sum") - + F.kl_div(m_log, y, log_target=True, reduction="sum") - ).to(self._device) + + if TORCH_VERSION_GE_160: + # log_target option can be used from 1.6.0 + y_pred_log = F.log_softmax(y_pred, dim=1) + y_log = F.log_softmax(y, dim=1) + self._sum_of_kl += ( + F.kl_div(m_log, y_pred_log, log_target=True, reduction="sum") + + F.kl_div(m_log, y_log, log_target=True, reduction="sum") + ).to(self._device) + else: + # y_pred and y are expected to be probabilities + self._sum_of_kl += ( + F.kl_div(m_log, y_pred_prob, reduction="sum") + F.kl_div(m_log, y_prob, reduction="sum") + ).to(self._device) @sync_all_reduce("_sum_of_kl", "_num_examples") def compute(self) -> float: diff --git a/ignite/metrics/kl_divergence.py b/ignite/metrics/kl_divergence.py index 99f6cbcfa84..93f6d5a8528 100644 --- a/ignite/metrics/kl_divergence.py +++ b/ignite/metrics/kl_divergence.py @@ -2,12 +2,15 @@ import torch import torch.nn.functional as F +from packaging.version import Version from ignite.exceptions import NotComputableError from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce __all__ = ["KLDivergence"] +TORCH_VERSION_GE_160 = Version(torch.__version__) >= Version("1.6.0") + class KLDivergence(Metric): r"""Calculates the mean of `Kullback-Leibler (KL) divergence @@ -91,8 +94,16 @@ def update(self, output: Sequence[torch.Tensor]) -> None: def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: y_pred = F.log_softmax(y_pred, dim=1) - y = F.log_softmax(y, dim=1) - kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum") + + if TORCH_VERSION_GE_160: + # log_target option can be used from 1.6.0 + y = F.log_softmax(y, dim=1) + kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum") + else: + # y is expected to be a probability tensor + y = F.softmax(y, dim=1) + kl_sum = F.kl_div(y_pred, y, reduction="sum") + self._sum_of_kl += kl_sum.to(self._device) @sync_all_reduce("_sum_of_kl", "_num_examples")