diff --git a/torchopt/_src/update.py b/torchopt/_src/update.py index 56e509401..5628b0cb0 100644 --- a/torchopt/_src/update.py +++ b/torchopt/_src/update.py @@ -35,7 +35,9 @@ from torchopt._src import base -def apply_updates(params: base.Params, updates: base.Updates, inplace: bool = True) -> base.Params: +def apply_updates( + params: 'base.Params', updates: 'base.Updates', inplace: bool = True +) -> 'base.Params': """Applies an update to the corresponding parameters. This is a utility functions that applies an update to a set of parameters,