Skip to content

Commit

Permalink
feat(hook): add nan_to_num hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 18, 2022
1 parent f831b53 commit 7f4b991
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
4 changes: 4 additions & 0 deletions docs/source/api/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,16 @@ Optimizer Hooks

register_hook
zero_nan_hook
nan_to_num_hook
nan_to_num

Hook
~~~~

.. autofunction:: register_hook
.. autofunction:: zero_nan_hook
.. autofunction:: nan_to_num_hook
.. autofunction:: nan_to_num

------

Expand Down
3 changes: 3 additions & 0 deletions torchopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torchopt.alias import adam, adamw, rmsprop, sgd
from torchopt.clip import clip_grad_norm
from torchopt.combine import chain
from torchopt.hook import nan_to_num, register_hook
from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
from torchopt.optim.func import FuncOptimizer
from torchopt.optim.meta import (
Expand Down Expand Up @@ -60,6 +61,8 @@
'rmsprop',
'sgd',
'clip_grad_norm',
'nan_to_num',
'register_hook',
'chain',
'Optimizer',
'SGD',
Expand Down
49 changes: 46 additions & 3 deletions torchopt/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,61 @@
# ==============================================================================
"""Hook utilities."""

from typing import Callable, Optional

import torch

from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation


__all__ = ['zero_nan_hook', 'register_hook']
__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'nan_to_num', 'register_hook']


def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
    """Replace every ``nan`` entry of the gradient ``g`` with ``0.0``.

    NOTE(review): :func:`torch.nan_to_num` also clamps ``+inf``/``-inf`` entries
    to the dtype's largest/most negative finite values by default.
    """
    return torch.nan_to_num(g, nan=0.0)


def nan_to_num_hook(
    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return a gradient hook that replaces ``nan``/``inf`` entries with given numbers.

    Args:
        nan: the value substituted for ``nan`` entries.
        posinf: the value substituted for ``+inf`` entries; ``None`` means the
            greatest finite value representable by the tensor's dtype.
        neginf: the value substituted for ``-inf`` entries; ``None`` means the
            most negative finite value representable by the tensor's dtype.

    Returns:
        A hook callable suitable for :func:`register_hook`.
    """

    def hook(g: torch.Tensor) -> torch.Tensor:
        # Fixed docstring: the original said "replace nan with zero", copy-pasted
        # from ``zero_nan_hook`` — this hook uses the *configured* numbers.
        """Replace ``nan``/``posinf``/``neginf`` entries of ``g`` with the configured numbers."""
        return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)

    return hook


def nan_to_num(
    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
) -> GradientTransformation:
    """A gradient transformation that replaces gradient values of ``nan`` with given number.

    Args:
        nan: the value substituted for ``nan`` update entries.
        posinf: the value substituted for ``+inf`` entries; ``None`` keeps the
            :func:`torch.nan_to_num` default (the dtype's largest finite value).
        neginf: the value substituted for ``-inf`` entries; ``None`` keeps the
            :func:`torch.nan_to_num` default (the dtype's most negative finite value).

    Returns:
        An ``(init_fn, update_fn)`` tuple.
    """

    def init_fn(params):  # pylint: disable=unused-argument
        # This transformation is stateless.
        return EmptyState()

    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
        if inplace:

            def f(g):
                # In-place variant: mutates the update tensors directly.
                return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf)

        else:

            def f(g):
                # Out-of-place variant: returns new tensors, originals untouched.
                return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)

        # Apply the sanitizer to every tensor leaf of the updates pytree.
        new_updates = pytree.tree_map(f, updates)
        return new_updates, state

    return GradientTransformation(init_fn, update_fn)


def register_hook(hook) -> GradientTransformation:
Expand Down

0 comments on commit 7f4b991

Please # to comment.