Fix error in old PyTorch for KL and JS divergence #3236

Merged
13 commits merged on Apr 26, 2024
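For context: torch.nn.functional.kl_div only gained the log_target keyword in PyTorch 1.6.0, so on older versions the target has to be passed as plain probabilities rather than log-probabilities. A minimal sketch of the two equivalent calls (runnable on PyTorch >= 1.6.0; only the probability-target form exists before that):

import torch
import torch.nn.functional as F

y_pred, y = torch.randn(4, 10), torch.randn(4, 10)  # raw logits
log_p = F.log_softmax(y_pred, dim=1)

# PyTorch >= 1.6.0: target given as log-probabilities
kl_new = F.kl_div(log_p, F.log_softmax(y, dim=1), log_target=True, reduction="sum")

# Pre-1.6.0 fallback: target given as probabilities
kl_old = F.kl_div(log_p, F.softmax(y, dim=1), reduction="sum")

assert torch.allclose(kl_new, kl_old)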
25 changes: 18 additions & 7 deletions ignite/metrics/js_divergence.py
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from packaging.version import Version

from ignite.exceptions import NotComputableError
from ignite.metrics.kl_divergence import KLDivergence
@@ -71,14 +72,24 @@ class JSDivergence(KLDivergence):
"""

def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
m_prob = (F.softmax(y_pred, dim=1) + F.softmax(y, dim=1)) / 2
y_pred_prob = F.softmax(y_pred, dim=1)
y_prob = F.softmax(y, dim=1)
m_prob = (y_pred_prob + y_prob) / 2
m_log = m_prob.log()
y_pred = F.log_softmax(y_pred, dim=1)
y = F.log_softmax(y, dim=1)
self._sum_of_kl += (
F.kl_div(m_log, y_pred, log_target=True, reduction="sum")
+ F.kl_div(m_log, y, log_target=True, reduction="sum")
).to(self._device)

if Version(torch.__version__) >= Version("1.6.0"):
            # log_target option can be used from 1.6.0
            y_pred_log = F.log_softmax(y_pred, dim=1)
            y_log = F.log_softmax(y, dim=1)
            self._sum_of_kl += (
                F.kl_div(m_log, y_pred_log, log_target=True, reduction="sum")
                + F.kl_div(m_log, y_log, log_target=True, reduction="sum")
            ).to(self._device)
        else:
            # y_pred and y are expected to be probabilities
            self._sum_of_kl += (
                F.kl_div(m_log, y_pred_prob, reduction="sum") + F.kl_div(m_log, y_prob, reduction="sum")
            ).to(self._device)

    @sync_all_reduce("_sum_of_kl", "_num_examples")
    def compute(self) -> float:
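A quick way to sanity-check that the new version-gated branch matches the pre-1.6.0 fallback above is to evaluate both on the same logits (a sketch, assuming PyTorch >= 1.6.0 so the log_target branch is available):

import torch
import torch.nn.functional as F

y_pred, y = torch.randn(8, 5), torch.randn(8, 5)

y_pred_prob = F.softmax(y_pred, dim=1)
y_prob = F.softmax(y, dim=1)
m_log = ((y_pred_prob + y_prob) / 2).log()

# Branch taken on PyTorch >= 1.6.0: log-probability targets
new = F.kl_div(m_log, F.log_softmax(y_pred, dim=1), log_target=True, reduction="sum") + F.kl_div(
    m_log, F.log_softmax(y, dim=1), log_target=True, reduction="sum"
)

# Fallback branch for older PyTorch: probability targets
old = F.kl_div(m_log, y_pred_prob, reduction="sum") + F.kl_div(m_log, y_prob, reduction="sum")

assert torch.allclose(new, old)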
13 changes: 11 additions & 2 deletions ignite/metrics/kl_divergence.py
@@ -2,6 +2,7 @@

import torch
import torch.nn.functional as F
from packaging.version import Version

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
@@ -91,8 +92,16 @@ def update(self, output: Sequence[torch.Tensor]) -> None:

    def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
        y_pred = F.log_softmax(y_pred, dim=1)
        y = F.log_softmax(y, dim=1)
        kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")

        if Version(torch.__version__) >= Version("1.6.0"):
            # log_target option can be used from 1.6.0
            y = F.log_softmax(y, dim=1)
            kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")
        else:
            # y is expected to be a probability tensor
            y = F.softmax(y, dim=1)
            kl_sum = F.kl_div(y_pred, y, reduction="sum")

        self._sum_of_kl += kl_sum.to(self._device)

    @sync_all_reduce("_sum_of_kl", "_num_examples")
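For reference, a minimal usage sketch of the two metrics touched by this PR, assuming both are exported from ignite.metrics and follow the standard ignite Metric API (update takes a (y_pred, y) pair of raw logits, compute returns the averaged divergence):

import torch
from ignite.metrics import JSDivergence, KLDivergence

kl = KLDivergence()
js = JSDivergence()

y_pred = torch.randn(16, 10)  # raw logits; softmax is applied inside the metrics
y = torch.randn(16, 10)

kl.update((y_pred, y))
js.update((y_pred, y))

print(kl.compute(), js.compute())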