From 7f4b9914a710803c700703f3f31c6cdda28c8681 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Fri, 18 Nov 2022 23:35:38 +0800
Subject: [PATCH] feat(hook): add `nan_to_num` hooks

---
 docs/source/api/api.rst |  4 ++++
 torchopt/__init__.py    |  3 +++
 torchopt/hook.py        | 49 ++++++++++++++++++++++++++++++++++++++---
 3 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst
index 97d8af307..7680e402e 100644
--- a/docs/source/api/api.rst
+++ b/docs/source/api/api.rst
@@ -186,12 +186,16 @@ Optimizer Hooks
 
     register_hook
     zero_nan_hook
+    nan_to_num_hook
+    nan_to_num
 
 Hook
 ~~~~
 
 .. autofunction:: register_hook
 .. autofunction:: zero_nan_hook
+.. autofunction:: nan_to_num_hook
+.. autofunction:: nan_to_num
 
 ------
 
diff --git a/torchopt/__init__.py b/torchopt/__init__.py
index 40da95f35..fd408e3ab 100644
--- a/torchopt/__init__.py
+++ b/torchopt/__init__.py
@@ -31,6 +31,7 @@
 from torchopt.alias import adam, adamw, rmsprop, sgd
 from torchopt.clip import clip_grad_norm
 from torchopt.combine import chain
+from torchopt.hook import nan_to_num, register_hook
 from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
 from torchopt.optim.func import FuncOptimizer
 from torchopt.optim.meta import (
@@ -60,6 +61,8 @@
     'rmsprop',
     'sgd',
     'clip_grad_norm',
+    'nan_to_num',
+    'register_hook',
     'chain',
     'Optimizer',
     'SGD',
diff --git a/torchopt/hook.py b/torchopt/hook.py
index 7dd9a66a1..1868e850f 100644
--- a/torchopt/hook.py
+++ b/torchopt/hook.py
@@ -14,18 +14,61 @@
 # ==============================================================================
 """Hook utilities."""
 
+from typing import Callable, Optional
+
 import torch
 
 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation
 
 
-__all__ = ['zero_nan_hook', 'register_hook']
+__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'nan_to_num', 'register_hook']
 
 
 def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
-    """Registers a zero nan hook to replace nan with zero."""
-    return torch.where(torch.isnan(g), torch.zeros_like(g), g)
+    """A zero ``nan`` hook to replace ``nan`` with zero."""
+    return g.nan_to_num(nan=0.0)
+
+
+def nan_to_num_hook(
+    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
+) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Returns a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers."""
+
+    def hook(g: torch.Tensor) -> torch.Tensor:
+        """A ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers."""
+        return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
+
+    return hook
+
+
+def nan_to_num(
+    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
+) -> GradientTransformation:
+    """A gradient transformation that replaces gradient values of ``nan`` / ``+inf`` / ``-inf`` with the given numbers.
+
+    Returns:
+        An ``(init_fn, update_fn)`` tuple.
+    """
+
+    def init_fn(params):  # pylint: disable=unused-argument
+        return EmptyState()
+
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
+        if inplace:
+
+            def f(g):
+                return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf)
+
+        else:
+
+            def f(g):
+                return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
+
+        new_updates = pytree.tree_map(f, updates)
+        return new_updates, state
+
+    return GradientTransformation(init_fn, update_fn)
 
 
 def register_hook(hook) -> GradientTransformation:
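
Usage sketch (not part of the patch): the new `nan_to_num` transformation is
meant to be chained in front of an optimizer so that `nan` / `inf` gradient
entries are sanitized before the update is applied. `torchopt.chain`,
`torchopt.sgd`, and `torchopt.apply_updates` below are the existing torchopt
functional API; the network and replacement values are illustrative.

    import torch
    import torchopt

    net = torch.nn.Linear(4, 1)
    loss = net(torch.randn(3, 4)).sum()

    # Replace nan -> 0.0 and +/-inf -> +/-1e6 in the gradients, then apply SGD.
    impl = torchopt.chain(
        torchopt.nan_to_num(nan=0.0, posinf=1e6, neginf=-1e6),
        torchopt.sgd(lr=0.1),
    )

    params = tuple(net.parameters())
    state = impl.init(params)

    grads = torch.autograd.grad(loss, params)
    updates, state = impl.update(grads, state, params=params)
    params = torchopt.apply_updates(params, updates)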