From 565e8be07b29f37dc02096a78857e0fa3930d314 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 27 Apr 2024 05:04:53 +0900
Subject: [PATCH] Fix error in old PyTorch for KL and JS divergence (#3236)

* add KLDivergence metric

* add JSDivergence

* fix variable name

* update docstring for JSDivergence

* Update ignite/metrics/js_divergence.py

Co-authored-by: vfdev

* Update ignite/metrics/kl_divergence.py

Co-authored-by: vfdev

* swap ground truth and prediction

* swap the definitions of p and q

* fix error in old pytorch

* switch to use log_target option by version

* check pytorch version in the global space in advance

---------

Co-authored-by: vfdev
---
 ignite/metrics/js_divergence.py | 27 ++++++++++++++++++++-------
 ignite/metrics/kl_divergence.py | 15 +++++++++++++--
 2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/ignite/metrics/js_divergence.py b/ignite/metrics/js_divergence.py
index 1bd37cfedc6..ee223014061 100644
--- a/ignite/metrics/js_divergence.py
+++ b/ignite/metrics/js_divergence.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from packaging.version import Version
 
 from ignite.exceptions import NotComputableError
 from ignite.metrics.kl_divergence import KLDivergence
@@ -7,6 +8,8 @@
 
 __all__ = ["JSDivergence"]
 
+TORCH_VERSION_GE_160 = Version(torch.__version__) >= Version("1.6.0")
+
 
 class JSDivergence(KLDivergence):
     r"""Calculates the mean of `Jensen-Shannon (JS) divergence
@@ -71,14 +74,24 @@ class JSDivergence(KLDivergence):
     """
 
     def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
-        m_prob = (F.softmax(y_pred, dim=1) + F.softmax(y, dim=1)) / 2
+        y_pred_prob = F.softmax(y_pred, dim=1)
+        y_prob = F.softmax(y, dim=1)
+        m_prob = (y_pred_prob + y_prob) / 2
         m_log = m_prob.log()
-        y_pred = F.log_softmax(y_pred, dim=1)
-        y = F.log_softmax(y, dim=1)
-        self._sum_of_kl += (
-            F.kl_div(m_log, y_pred, log_target=True, reduction="sum")
-            + F.kl_div(m_log, y, log_target=True, reduction="sum")
-        ).to(self._device)
+
+        if TORCH_VERSION_GE_160:
+            # log_target option can be used from 1.6.0
+            y_pred_log = F.log_softmax(y_pred, dim=1)
+            y_log = F.log_softmax(y, dim=1)
+            self._sum_of_kl += (
+                F.kl_div(m_log, y_pred_log, log_target=True, reduction="sum")
+                + F.kl_div(m_log, y_log, log_target=True, reduction="sum")
+            ).to(self._device)
+        else:
+            # y_pred and y are expected to be probabilities
+            self._sum_of_kl += (
+                F.kl_div(m_log, y_pred_prob, reduction="sum") + F.kl_div(m_log, y_prob, reduction="sum")
+            ).to(self._device)
 
     @sync_all_reduce("_sum_of_kl", "_num_examples")
     def compute(self) -> float:
diff --git a/ignite/metrics/kl_divergence.py b/ignite/metrics/kl_divergence.py
index 99f6cbcfa84..93f6d5a8528 100644
--- a/ignite/metrics/kl_divergence.py
+++ b/ignite/metrics/kl_divergence.py
@@ -2,12 +2,15 @@
 
 import torch
 import torch.nn.functional as F
+from packaging.version import Version
 
 from ignite.exceptions import NotComputableError
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
 
 __all__ = ["KLDivergence"]
 
+TORCH_VERSION_GE_160 = Version(torch.__version__) >= Version("1.6.0")
+
 
 class KLDivergence(Metric):
     r"""Calculates the mean of `Kullback-Leibler (KL) divergence
@@ -91,8 +94,16 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
 
     def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
         y_pred = F.log_softmax(y_pred, dim=1)
-        y = F.log_softmax(y, dim=1)
-        kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")
+
+        if TORCH_VERSION_GE_160:
+            # log_target option can be used from 1.6.0
+            y = F.log_softmax(y, dim=1)
+            kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")
+        else:
+            # y is expected to be a probability tensor
+            y = F.softmax(y, dim=1)
+            kl_sum = F.kl_div(y_pred, y, reduction="sum")
+
        self._sum_of_kl += kl_sum.to(self._device)
 
     @sync_all_reduce("_sum_of_kl", "_num_examples")
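
Note (not part of the patch above): a minimal self-contained sketch, with illustrative variable names and plain PyTorch only, of why the pre-1.6.0 fallback matches the log_target path. F.kl_div expects its input as log-probabilities; the target may be given either as probabilities (all versions) or as log-probabilities with log_target=True (available from PyTorch 1.6.0), and both forms yield the same sum.

    import torch
    import torch.nn.functional as F

    y_pred = torch.randn(4, 3)  # logits of the predicted distribution
    y = torch.randn(4, 3)       # logits of the reference distribution

    log_q = F.log_softmax(y_pred, dim=1)  # input to F.kl_div: log-probabilities

    # Path guarded by TORCH_VERSION_GE_160: target passed as log-probabilities.
    kl_new = F.kl_div(log_q, F.log_softmax(y, dim=1), log_target=True, reduction="sum")

    # Fallback for PyTorch < 1.6.0: target passed as probabilities instead.
    kl_old = F.kl_div(log_q, F.softmax(y, dim=1), reduction="sum")

    assert torch.allclose(kl_new, kl_old)  # both compute the same KL sum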