Update sgdw for older pytorch
rwightman committed Dec 11, 2023 · 1 parent 60b170b · commit 711c5de

Showing 1 changed file with 23 additions and 12 deletions.
timm/optim/sgdw.py
@@ -1,6 +1,13 @@
 from functools import update_wrapper, wraps
 import torch
 from torch import Tensor
-from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from torch.optim.optimizer import Optimizer
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
+    has_recent_pt = True
+except ImportError:
+    has_recent_pt = False
 
 from typing import List, Optional
 
 __all__ = ['SGDW', 'sgdw']
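The try/except above doubles as a feature probe: _use_grad_for_differentiable and _default_to_fused_or_foreach are private helpers that only newer PyTorch releases export, so a failed import marks the install as old. A minimal sketch of an equivalent probe (not part of this commit), assuming only that the attribute name is stable:

import torch.optim.optimizer as _optim_mod

# Equivalent capability check via hasattr rather than catching ImportError.
has_recent_pt = hasattr(_optim_mod, '_use_grad_for_differentiable')

Probing for the capability directly is generally more robust than parsing torch.__version__, since it keeps working if the helper is ever backported or relocated.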
@@ -62,7 +69,9 @@ def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
 
         return has_sparse_grad
 
-    @_use_grad_for_differentiable
+    # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
+    #  without args, for backwards compatibility with old pytorch
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -124,17 +133,19 @@ def sgdw(
     See :class:`~torch.optim.SGD` for details.
     """
-    if foreach is None:
-        # why must we be explicit about an if statement for torch.jit.is_scripting here?
-        # because JIT can't handle Optionals nor fancy conditionals when scripting
-        if not torch.jit.is_scripting():
-            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
-        else:
-            foreach = False
-
-    if foreach and torch.jit.is_scripting():
-        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
+        if foreach is None:
+            # why must we be explicit about an if statement for torch.jit.is_scripting here?
+            # because JIT can't handle Optionals nor fancy conditionals when scripting
+            if not torch.jit.is_scripting():
+                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
+            else:
+                foreach = False
+
+        if foreach and torch.jit.is_scripting():
+            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    else:
+        foreach = False  # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype
 
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_sgdw
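For callers, the practical effect of the new else branch is that on older PyTorch sgdw() always takes the single-tensor path, even if foreach=True was requested, while recent PyTorch can still auto-select the multi-tensor path. A usage sketch (assuming timm is installed; the module exports SGDW per the __all__ above):

import torch
from timm.optim.sgdw import SGDW

model = torch.nn.Linear(10, 2)
# foreach=None lets recent PyTorch auto-select the multi-tensor path;
# on older builds the guard above forces the single-tensor loop regardless.
opt = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
opt.zero_grad()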
