From 3f466664d997ef3f9b0e910eaf692dc5a0cfaebc Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 2 Nov 2022 23:55:44 +0800 Subject: [PATCH] chore: resolve lint --- examples/iMAML/imaml_omniglot.py | 2 +- tests/test_implicit.py | 2 +- torchopt/diff/implicit/nn/module.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py index 65679dca..3e0bb364 100644 --- a/examples/iMAML/imaml_omniglot.py +++ b/examples/iMAML/imaml_omniglot.py @@ -51,7 +51,7 @@ class InnerNet( def __init__(self, meta_net, n_inner_iter, reg_param): super().__init__() self.meta_net = meta_net - self.net = copy.deepcopy(meta_net) + self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True) self.n_inner_iter = n_inner_iter self.reg_param = reg_param diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 6e6b9d60..06d180d4 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -243,7 +243,7 @@ class InnerNet(ImplicitMetaGradientModule, has_aux=True): def __init__(self, meta_model): super().__init__() self.meta_model = meta_model - self.model = copy.deepcopy(meta_model) + self.model = torchopt.module_clone(meta_model, by='deepcopy', detach_buffers=True) def forward(self, x): return self.model(x) diff --git a/torchopt/diff/implicit/nn/module.py b/torchopt/diff/implicit/nn/module.py index f0e0de5f..4ad48d32 100644 --- a/torchopt/diff/implicit/nn/module.py +++ b/torchopt/diff/implicit/nn/module.py @@ -81,9 +81,9 @@ def enable_implicit_gradients( raise ValueError('Implicit gradients are already enabled for the solve function.') cls_has_aux = cls.has_aux - custom_root_kwargs = dict(has_aux=cls_has_aux) - if cls.linear_solve is not None: - custom_root_kwargs.update(solve=cls.linear_solve) + custom_root_kwargs = dict(has_aux=cls_has_aux, solve=cls.linear_solve) + if cls.linear_solve is None: + custom_root_kwargs.pop('solve') @functools.wraps(cls_solve) def wrapped( # pylint: disable=too-many-locals @@ -145,7 +145,7 @@ def optimality_fn( ): container.update(container_backup) - @custom_root(optimality_fn, argnums=1, **custom_root_kwargs) + @custom_root(optimality_fn, argnums=1, **custom_root_kwargs) # type: ignore[arg-type] def solver_fn( flat_params: TupleOfTensors, # pylint: disable=unused-argument flat_meta_params: TupleOfTensors, # pylint: disable=unused-argument