Skip to content

Commit

Permalink
chore: resolve lint
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Nov 2, 2022
1 parent 886639b commit 3f46666
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/iMAML/imaml_omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class InnerNet(
def __init__(self, meta_net, n_inner_iter, reg_param):
super().__init__()
self.meta_net = meta_net
self.net = copy.deepcopy(meta_net)
self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)
self.n_inner_iter = n_inner_iter
self.reg_param = reg_param

Expand Down
2 changes: 1 addition & 1 deletion tests/test_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class InnerNet(ImplicitMetaGradientModule, has_aux=True):
def __init__(self, meta_model):
super().__init__()
self.meta_model = meta_model
self.model = copy.deepcopy(meta_model)
self.model = torchopt.module_clone(meta_model, by='deepcopy', detach_buffers=True)

def forward(self, x):
return self.model(x)
Expand Down
8 changes: 4 additions & 4 deletions torchopt/diff/implicit/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def enable_implicit_gradients(
raise ValueError('Implicit gradients are already enabled for the solve function.')

cls_has_aux = cls.has_aux
custom_root_kwargs = dict(has_aux=cls_has_aux)
if cls.linear_solve is not None:
custom_root_kwargs.update(solve=cls.linear_solve)
custom_root_kwargs = dict(has_aux=cls_has_aux, solve=cls.linear_solve)
if cls.linear_solve is None:
custom_root_kwargs.pop('solve')

@functools.wraps(cls_solve)
def wrapped( # pylint: disable=too-many-locals
Expand Down Expand Up @@ -145,7 +145,7 @@ def optimality_fn(
):
container.update(container_backup)

@custom_root(optimality_fn, argnums=1, **custom_root_kwargs)
@custom_root(optimality_fn, argnums=1, **custom_root_kwargs) # type: ignore[arg-type]
def solver_fn(
flat_params: TupleOfTensors, # pylint: disable=unused-argument
flat_meta_params: TupleOfTensors, # pylint: disable=unused-argument
Expand Down

0 comments on commit 3f46666

Please # to comment.