Resolved log(0) error in KL divergence Issue#12233
anscian committed Feb 4, 2025
1 parent 6c92c5a commit 1f98734
Showing 1 changed file with 3 additions and 1 deletion.
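For context: the discrete Kullback-Leibler divergence implemented here is

    D_{KL}(P \| Q) = \sum_i p_i \log \frac{p_i}{q_i}

and the old elementwise expression y_true * np.log(y_true / y_pred) hits log(0) or 0/0 whenever either array contains a zero entry. The commit filters those positions out before taking the logarithm, following the convention 0 * log(0) = 0.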
4 changes: 3 additions & 1 deletion machine_learning/loss_functions.py
@@ -659,7 +659,9 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
     if len(y_true) != len(y_pred):
         raise ValueError("Input arrays must have the same length.")
 
-    kl_loss = y_true * np.log(y_true / y_pred)
+    kl_loss = np.concatenate((y_true[None, :], y_pred[None, :]))  # true probs in first row and predicted in second

Check failure (GitHub Actions / ruff): machine_learning/loss_functions.py:662:89: E501 Line too long (114 > 88)
+    kl_loss = kl_loss[:, np.any(kl_loss == 0, axis=0) == False]  # Filtered zero probabilities from both probability arrays

Check failure (GitHub Actions / ruff): machine_learning/loss_functions.py:663:26: E712 Avoid equality comparisons to `False`; use `if not np.any(kl_loss == 0, axis=0):` for false checks

Check failure (GitHub Actions / ruff): machine_learning/loss_functions.py:663:89: E501 Line too long (122 > 88)
+    kl_loss = kl_loss[0] * np.log(kl_loss[0] / kl_loss[1])  # Calculating safely now
     return np.sum(kl_loss)
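
For readers following the CI failures above: here is a minimal sketch (not the committed code) of the same zero-filtering fix, rewritten with a boolean mask so that it would also satisfy the Ruff E712 and E501 checks:

import numpy as np

def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Sketch: KL divergence with zero probabilities filtered out."""
    if len(y_true) != len(y_pred):
        raise ValueError("Input arrays must have the same length.")
    # Keep only positions where both probabilities are nonzero; this mirrors
    # the commit's filtering and uses the convention 0 * log(0) = 0.
    mask = (y_true != 0) & (y_pred != 0)
    y_true, y_pred = y_true[mask], y_pred[mask]
    return float(np.sum(y_true * np.log(y_true / y_pred)))

For example, kullback_leibler_divergence(np.array([0.5, 0.5, 0.0]), np.array([0.4, 0.6, 0.0])) returns a finite value (about 0.0204) where the old elementwise expression produced nan. Note that, like the committed code, this also drops positions where only y_pred is zero, even though the divergence is strictly infinite there.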



0 comments on commit 1f98734
