From 847549bba8a28c59c833913ac0f5207b097dc84b Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Sun, 11 Sep 2022 18:41:54 +0800
Subject: [PATCH] chore: cleanup

---
 examples/MAML-RL/func_maml.py        | 2 +-
 torchopt/_src/optimizer/func/base.py | 9 ++++++---
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/examples/MAML-RL/func_maml.py b/examples/MAML-RL/func_maml.py
index ce3d9ef3..6413cc71 100644
--- a/examples/MAML-RL/func_maml.py
+++ b/examples/MAML-RL/func_maml.py
@@ -189,7 +189,7 @@ def main(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Reinforcement learning with ' 'Model-Agnostic Meta-Learning (MAML) - Train'
+        description='Reinforcement learning with Model-Agnostic Meta-Learning (MAML) - Train'
     )
     parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
     args = parser.parse_args()
diff --git a/torchopt/_src/optimizer/func/base.py b/torchopt/_src/optimizer/func/base.py
index b2c0ee90..9a4abb05 100644
--- a/torchopt/_src/optimizer/func/base.py
+++ b/torchopt/_src/optimizer/func/base.py
@@ -24,8 +24,8 @@
 
 # mypy: ignore-errors
 class FuncOptimizer:  # pylint: disable=too-few-public-methods
-    """A wrapper class to hold the functional optimizer.
-    It makes it easier to maintain the optimizer states.
+    """A wrapper class to hold the functional optimizer. It makes it easier to maintain the
+    optimizer states.
 
     See Also:
         - The functional Adam optimizer: :func:`torchopt.adam`.
@@ -40,7 +40,7 @@ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> No
         Args:
             impl (GradientTransformation): A low level optimizer function, it could be a
                 optimizer function provided by `alias.py` or a customized `chain` provided by `combine.py`.
-            inplace: (default: :data:`False`)
+            inplace (optional): (default: :data:`False`)
                 The default value of ``inplace`` for each optimization update.
         """
         self.impl = impl
@@ -61,6 +61,9 @@ def step(
                 loss that is used to compute the gradients to network parameters.
             params: (tree of torch.Tensor) An tree of :class:`torch.Tensor`\s. Specifies what
                 tensors should be optimized.
+            inplace (optional): (default: :data:`None`)
+                Whether to update the parameters in-place. If :data:`None`, use the default value
+                specified in the constructor.
         """
         if self.optim_state is None:
             self.optim_state = self.impl.init(params)
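
Note for reviewers (not part of the patch): a minimal sketch of how the
`FuncOptimizer` wrapper whose docstrings are edited above is typically driven.
The module, data, and training loop here are invented for illustration; the
sketch assumes `torchopt.FuncOptimizer` and `torchopt.adam` are exported at
the top level and that functorch's `make_functional` is available, as in
torchopt releases contemporary with this patch.

    import torch
    import torchopt
    from functorch import make_functional

    net = torch.nn.Linear(4, 1)
    fnet, params = make_functional(net)  # functional view: fnet(params, x)

    # Wrap a functional optimizer (here Adam). `inplace=False`, the
    # constructor default documented above, keeps updates out-of-place.
    optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=1e-3))

    xs, ys = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(5):
        loss = torch.nn.functional.mse_loss(fnet(params, xs), ys)
        # `step` computes gradients of `loss` w.r.t. `params`, applies the
        # update, maintains the optimizer state internally, and returns the
        # new parameter tree.
        params = optimizer.step(loss, params)

The wrapper exists precisely so that callers do not have to thread the
optimizer state through `init`/`update` by hand; `step` hides that state,
which is what the reworded docstring above describes.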