diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 00000000..1ee2f625
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,35 @@
+# https://editorconfig.org/
+
+root = true
+
+[*]
+charset = utf-8
+end_of_line = lf
+indent_style = space
+indent_size = 4
+trim_trailing_whitespace = true
+insert_final_newline = true
+
+[*.py]
+indent_size = 4
+src_paths=torchopt,tests,examples
+
+[*.md]
+indent_size = 2
+x-soft-wrap-text = true
+
+[*.rst]
+indent_size = 4
+x-soft-wrap-text = true
+
+[Makefile]
+indent_style = tab
+
+[*.cpp]
+indent_size = 2
+
+[*.h]
+indent_size = 2
+
+[*.cuh?]
+indent_size = 2
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
index a74f7620..9520f2ee 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -20,7 +20,7 @@ Please try to provide a minimal example to reproduce the bug. Error messages and
Please use the markdown code blocks for both code and stack traces.
```python
-import metarl
+import torchopt
```
```bash
@@ -43,8 +43,8 @@ Describe the characteristic of your environment:
* Versions of any other relevant libraries
```python
-import metarl, numpy, sys
-print(metarl.__version__, numpy.__version__, sys.version, sys.platform)
+import torchopt, numpy, sys
+print(torchopt.__version__, numpy.__version__, sys.version, sys.platform)
```
## Additional context
@@ -58,5 +58,5 @@ If you know or suspect the reason for this bug, paste the code lines and suggest
## Checklist
- [ ] I have checked that there is no similar issue in the repo (**required**)
-- [ ] I have read the [documentation](https://metarl.readthedocs.io/) (**required**)
+- [ ] I have read the [documentation](https://torchopt.readthedocs.io/) (**required**)
- [ ] I have provided a minimal working example to reproduce the bug (**required**)
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 064d15bc..b19443c7 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -8,7 +8,7 @@ Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213
-- [ ] I have raised an issue to propose this change ([required](https://metarl.readthedocs.io/en/latest/pages/contributing.html) for new features and bug fixes)
+- [ ] I have raised an issue to propose this change ([required](https://torchopt.readthedocs.io/en/latest/pages/contributing.html) for new features and bug fixes)
## Types of changes
@@ -32,7 +32,7 @@ What types of changes does your code introduce? Put an `x` in all the boxes that
Go over all the following points, and put an `x` in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!
-- [ ] I have read the [CONTRIBUTION](https://metarl.readthedocs.io/en/latest/pages/contributing.html) guide (**required**)
+- [ ] I have read the [CONTRIBUTION](https://torchopt.readthedocs.io/en/latest/pages/contributing.html) guide (**required**)
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
diff --git a/.gitignore b/.gitignore
index 5a67f740..87e9b834 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,8 +2,8 @@
.idea
build
__pycache__
-TorchOpt/**/*.so
-TorchOpt.egg-info
+torchopt/**/*.so
+torchopt.egg-info
dist
**/.ipynb_checkpoints/*
@@ -152,4 +152,4 @@ dmypy.json
.pytype/
# Cython debug symbols
-cython_debug/
\ No newline at end of file
+cython_debug/
diff --git a/CITATION.cff b/CITATION.cff
index 5c239556..fdfacfc4 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -27,4 +27,4 @@ authors:
version: 0.4.1
date-released: "2022-04-09"
license: Apache-2.0
-repository-code: "https://github.com/metaopt/TorchOpt"
+repository-code: "https://github.com/metaopt/torchopt"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 546a4f26..808d40c5 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,7 +14,7 @@
# ==============================================================================
cmake_minimum_required(VERSION 3.1)
-project(TorchOpt LANGUAGES CXX CUDA)
+project(torchopt LANGUAGES CXX CUDA)
find_package(CUDA REQUIRED)
diff --git a/Makefile b/Makefile
index fc0ade67..6e07d1a1 100644
--- a/Makefile
+++ b/Makefile
@@ -1,9 +1,9 @@
print-% : ; @echo $* = $($*)
SHELL = /bin/bash
-PROJECT_NAME = TorchOpt
+PROJECT_NAME = torchopt
PROJECT_PATH = ${PROJECT_NAME}/
PROJECT_FOLDER = $(PROJECT_NAME) examples include src tests
-PYTHON_FILES = $(shell find . -type f -name "*.py")
+PYTHON_FILES = $(shell find examples torchopt tests -type f -name "*.py" -o -name "*.pyi")
CPP_FILES = $(shell find . -type f -name "*.h" -o -name "*.cpp" -o -name "*.cuh" -o -name "*.cu")
COMMIT_HASH = $(shell git log -1 --format=%h)
COPYRIGHT = "MetaOPT Team. All Rights Reserved."
@@ -66,7 +66,8 @@ flake8: flake8-install
flake8 $(PYTHON_FILES) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
py-format: py-format-install
- isort --check $(PYTHON_FILES) && yapf -ir $(PYTHON_FILES)
+ isort --project torchopt --check $(PYTHON_FILES) && \
+ yapf --in-place --recursive $(PYTHON_FILES)
mypy: mypy-install
mypy $(PROJECT_NAME)
@@ -103,4 +104,3 @@ format: py-format-install clang-format-install
yapf -ir $(PYTHON_FILES)
clang-format-11 -style=file -i $(CPP_FILES)
addlicense -c $(COPYRIGHT) -l apache -y 2022 $(PROJECT_FOLDER)
-
diff --git a/README.md b/README.md
index 4ceb9de3..24f53664 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
**TorchOpt** is a high-performance optimizer library built upon [PyTorch](https://pytorch.org/) for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features:
-- TorchOpt provides functional optimizer which enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX.
+- TorchOpt provides functional optimizer which enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX.
- With the desgin of functional programing, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms.
--------------------------------------------------------------------------------
@@ -21,35 +21,36 @@ The README is organized as follows:
- [Installation](#installation)
- [Future Plan](#future-plan)
- [The Team](#the-team)
+- [Citing TorchOpt](#citing-torchopt)
## TorchOpt as Functional Optimizer
The desgin of TorchOpt follows the philosophy of functional programming. Aligned with [functorch](https://github.com/pytorch/functorch), users can conduct functional style programing with models, optimizers and training in PyTorch. We use the Adam optimizer as an example in the following illustration. You can also check out the tutorial notebook [Functional Optimizer](./tutorials/1_Functional_Optimizer.ipynb) for more details.
### Optax-Like API
-For those users who prefer fully functional programing, we offer Optax-Like API by passing gradients and optimizers states to the optimizer function. We design base class `TorchOpt.Optimizer` that has the same interface as `torch.optim.Optimizer`. Here is an example coupled with functorch:
+For those users who prefer fully functional programing, we offer Optax-Like API by passing gradients and optimizers states to the optimizer function. We design base class `torchopt.Optimizer` that has the same interface as `torch.optim.Optimizer`. Here is an example coupled with functorch:
```python
-import torch
-from torch import nn
-from torch import data
-from nn import functional as F
import functorch
-import TorchOpt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchopt
+from torch.utils.data import DataLoader
-class Net(nn.Module):...
+class Net(nn.Module): ...
-class Loader(data.DataLoader):...
+class Loader(DataLoader): ...
net = Net() # init
loader = Loader()
-optimizer = TorchOpt.adam()
+optimizer = torchopt.adam()
func, params = functorch.make_functional(net) # use functorch extract network parameters
opt_state = optimizer.init(params) # init optimizer
xs, ys = next(loader) # get data
pred = func(params, xs) # forward
-loss = F.cross_entropy(pred, ys) # compute loss
+loss = F.cross_entropy(pred, ys) # compute loss
grad = torch.autograd.grad(loss, params) # compute gradients
updates, opt_state = optimizer.update(grad, opt_state) # get updates
-params = TorchOpt.apply_updates(params, updates) # update network parameters
+params = torchopt.apply_updates(params, updates) # update network parameters
```
### PyTorch-Like API
We also offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by warpping our Optax-Like API for traditional PyTorch user:
@@ -57,7 +58,7 @@ We also offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by warpping o
```python
net = Net() # init
loader = Loader()
-optimizer = TorchOpt.Adam(net.parameters())
+optimizer = torchopt.Adam(net.parameters())
xs, ys = next(loader) # get data
pred = net(xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss
@@ -71,15 +72,15 @@ On top of the same optimization function as `torch.optim`, an important benefit
# get updates
updates, opt_state = optimizer.update(grad, opt_state, inplace=False)
# update network parameters
-params = TorchOpt.apply_updates(params, updates, inplace=False)
+params = torchopt.apply_updates(params, updates, inplace=False)
```
## TorchOpt as Differentiable Optimizer for Meta-Learning
-Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimisation process with *inner loop* updating the network parameters and *outer loop* updating meta parameters. The figure below illustrates the basic formulation for meta-optimization in Meta-Learning. The main feature is that the gradients of *outer loss* will back-propagate through all `inner.step` operations.
+Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimisation process with *inner loop* updating the network parameters and *outer loop* updating meta parameters. The figure below illustrates the basic formulation for meta-optimization in Meta-Learning. The main feature is that the gradients of *outer loss* will back-propagate through all `inner.step` operations.
-Since network parameters become a node of computation graph, a flexible Meta-Learning library should enable users manually control the gradient graph connection which means that users should have access to the network parameters and optimizer states for manually detaching or connecting the computation graph. In PyTorch designing, the network parameters or optimizer states are members of network (a.k.a. `nn.Module`) or optimizer (a.k.a. `optim.Optimizer`), this design significantly introducing difficulty for user control network parameters or optimizer states. Previous differentiable optimizer Repo [higher](https://github.com/facebookresearch/higher), [learn2learn](https://github.com/learnables/learn2learn) follows the PyTorch designing which leads to inflexible API.
+Since network parameters become a node of computation graph, a flexible Meta-Learning library should enable users manually control the gradient graph connection which means that users should have access to the network parameters and optimizer states for manually detaching or connecting the computation graph. In PyTorch designing, the network parameters or optimizer states are members of network (a.k.a. `nn.Module`) or optimizer (a.k.a. `optim.Optimizer`), this design significantly introducing difficulty for user control network parameters or optimizer states. Previous differentiable optimizer Repo [higher](https://github.com/facebookresearch/higher), [learn2learn](https://github.com/learnables/learn2learn) follows the PyTorch designing which leads to inflexible API.
In contrast to them, TorchOpt realizes differentiable optimizer with functional programing, where Meta-Learning researchers could control the network parameters or optimizer states as normal variables (a.k.a. `torch.Tensor`). This functional optimizer design of TorchOpt is beneficial for implementing complex gradient flow Meta-Learning algorithms and allow us to improve computational efficiency by using techniques like operator fusion.
@@ -91,8 +92,8 @@ We hope meta-learning researchers could control the network parameters or optimi
### Meta-Learning API
-- We design a base class `TorchOpt.MetaOptimizer` for managing network updates in Meta-Learning. The constructor of `MetaOptimizer` takes as input the network rather than network parameters. `MetaOptimizer` exposed interface `step(loss)` takes as input the loss for step the network parameter. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
-- We offer `TorchOpt.chain` which can apply a list of chainable update transformations. Combined with `MetaOptimizer`, it can help you conduct gradient transformation such as gradient clip before the Meta optimizer steps. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
+- We design a base class `torchopt.MetaOptimizer` for managing network updates in Meta-Learning. The constructor of `MetaOptimizer` takes as input the network rather than network parameters. `MetaOptimizer` exposed interface `step(loss)` takes as input the loss for step the network parameter. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
+- We offer `torchopt.chain` which can apply a list of chainable update transformations. Combined with `MetaOptimizer`, it can help you conduct gradient transformation such as gradient clip before the Meta optimizer steps. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
- We observe that different Meta-Learning algorithms vary in inner-loop parameter recovery. TorchOpt provides basic functions for users to extract or recover network parameters and optimizer states anytime anywhere they want.
- Some algorithms such as [MGRL](https://proceedings.neurips.cc/paper/2018/file/2715518c875999308842e3455eda2fe3-Paper.pdf) initialize the inner-loop parameters inherited from previous inner-loop process when conducting a new bi-level process. TorchOpt also provides a finer function `stop_gradient` for manipulating the gradient graph, which is helpful for this kind of algortihms. Refer to the notebook [Stop Gradient](./tutorials/4_Stop_Gradient.ipynb) for more details.
@@ -101,40 +102,40 @@ We give an example of [MAML](https://arxiv.org/abs/1703.03400) with inner-loop A
```python
net = Net() # init
# the constructor `MetaOptimizer` takes as input the network
-inner_optim = TorchOpt.MetaAdam(net)
-outer_optim = TorchOpt.Adam(net.parameters())
+inner_optim = torchopt.MetaAdam(net)
+outer_optim = torchopt.Adam(net.parameters())
for train_iter in range(train_iters):
outer_loss = 0
for task in range(tasks):
loader = Loader(tasks)
-
+
# store states at the inital points
- net_state = TorchOpt.extract_state_dict(net) # extract state
- optim_state = TorchOpt.extract_state_dict(inner_optim)
+ net_state = torchopt.extract_state_dict(net) # extract state
+ optim_state = torchopt.extract_state_dict(inner_optim)
for inner_iter in range(inner_iters):
# compute inner loss and perform inner update
xs, ys = next(loader)
pred = net(xs)
- inner_loss = F.cross_entropy(pred, ys)
+ inner_loss = F.cross_entropy(pred, ys)
inner_optim.step(inner_loss)
# compute outer loss and back-propagate
- xs, ys = next(loader)
+ xs, ys = next(loader)
pred = net(xs)
outer_loss += F.cross_entropy(pred, ys)
-
+
# recover network and optimizer states at the inital point for the next task
- TorchOpt.recover_state_dict(inner_optim, optim_state)
- TorchOpt.recover_state_dict(net, net_state)
-
+ torchopt.recover_state_dict(inner_optim, optim_state)
+ torchopt.recover_state_dict(net, net_state)
+
outer_loss /= len(tasks) # task average
outer_optim.zero_grad()
outer_loss.backward()
outer_optim.step()
# stop gradient if necessary
- TorchOpt.stop_gradient(net)
- TorchOpt.stop_gradient(inner_optim)
+ torchopt.stop_gradient(net)
+ torchopt.stop_gradient(inner_optim)
```
## Examples
In *examples/*, we offer serveral examples of functional optimizer and 5 light-weight meta-learning examples with TorchOpt. The meta-learning examples covers 2 Supervised Learning and 3 Reinforcement Learning algorithms.
@@ -168,13 +169,13 @@ Requirements
- (Optional) For visualizing computation graphs
- [Graphviz](https://graphviz.org/download/) (for Linux users use `apt/yum install graphviz` or `conda install -c anaconda python-graphviz`)
```bash
-pip install TorchOpt
+pip install torchopt
```
You can also build shared libraries from source, use:
```bash
-git clone git@github.com:metaopt/TorchOpt.git
-cd TorchOpt
+git clone git@github.com:metaopt/torchopt.git
+cd torchopt
python setup.py build_from_source
```
## Future Plan
@@ -196,6 +197,6 @@ If you find TorchOpt useful, please cite it in your publications.
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
- howpublished = {\url{https://github.com/metaopt/TorchOpt}},
+ howpublished = {\url{https://github.com/metaopt/torchopt}},
}
```
diff --git a/TorchOpt/__init__.py b/TorchOpt/__init__.py
deleted file mode 100644
index f42bd7c6..00000000
--- a/TorchOpt/__init__.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""TorchOpt: a high-performance optimizer library built upon PyTorch."""
-
-from TorchOpt._src import (
- accelerated_op_available,
- clip,
- combine,
- hook,
- schedule,
- visual,
-)
-from TorchOpt._src.alias import adam, rmsprop, sgd
-from TorchOpt._src.MetaOptimizer import (
- MetaAdam,
- MetaOptimizer,
- MetaRMSProp,
- MetaSGD,
-)
-from TorchOpt._src.Optimizer import SGD, Adam, Optimizer, RMSProp
-from TorchOpt._src.update import apply_updates
-from TorchOpt._src.utils import (
- extract_state_dict,
- recover_state_dict,
- stop_gradient,
-)
-
-__version__ = "0.4.1"
-
-__all__ = (
- "accelerated_op_available",
- "clip",
- "combine",
- "hook",
- "schedule",
- "visual",
- "adam",
- "rmsprop",
- "sgd",
- "MetaAdam",
- "MetaOptimizer",
- "MetaRMSProp",
- "MetaSGD",
- "SGD",
- "Adam",
- "Optimizer",
- "RMSProp",
- "apply_updates",
- "extract_state_dict",
- "recover_state_dict",
- "stop_gradient",
-)
diff --git a/TorchOpt/_lib/adam_op.py b/TorchOpt/_lib/adam_op.py
deleted file mode 100644
index ceb2eb9e..00000000
--- a/TorchOpt/_lib/adam_op.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import torch
-
-
-def forward_(
- updates: torch.Tensor, mu: torch.Tensor, nu: torch.Tensor, b1: float,
- b2: float, eps: float, eps_root: float, count: int
-) -> torch.Tensor:
- ...
-
-
-def forwardMu(
- updates: torch.Tensor, mu: torch.Tensor, b1: float
-) -> torch.Tensor:
- ...
-
-
-def forwardNu(
- updates: torch.Tensor, nu: torch.Tensor, b2: float
-) -> torch.Tensor:
- ...
-
-
-def forwardUpdates(
- new_mu: torch.Tensor, new_nu: torch.Tensor, b1: float, b2: float, eps: float,
- eps_root: float, count: int
-) -> torch.Tensor:
- ...
-
-
-def backwardMu(
- dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float
-) -> torch.Tensor:
- ...
-
-
-def backwardNu(
- dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float
-) -> torch.Tensor:
- ...
-
-
-def backwardUpdates(
- dupdates: torch.Tensor, updates: torch.Tensor, new_mu: torch.Tensor,
- new_nu: torch.Tensor, b1: float, b2: float, count: int
-) -> torch.Tensor:
- ...
diff --git a/TorchOpt/_src/MetaOptimizer.py b/TorchOpt/_src/MetaOptimizer.py
deleted file mode 100644
index f4cbd045..00000000
--- a/TorchOpt/_src/MetaOptimizer.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from typing import Union
-
-import jax
-import torch
-from torch import nn
-
-from TorchOpt._src import base
-from TorchOpt._src.alias import adam, rmsprop, sgd
-from TorchOpt._src.pytypes import ScalarOrSchedule
-from TorchOpt._src.update import apply_updates
-
-
-class MetaOptimizer(object):
- """A high-level optimizer base class for meta learning."""
-
- def __init__(self, net: nn.Module, impl: base.GradientTransformation):
- """
- Args:
- net (nn.Module): a network whose parameters should be optimized.
- impl (base.GradientTransformation): a low level optimizer function, it could be a
- optimizer function provided by `alias.py` or a customerized `chain` provided by
- `combine.py`. Note that use `MetaOptimizer(sgd(moment_requires_grad=True))` or
- `MetaOptimizer(chain(sgd(moment_requires_grad=True))) is equavalent to `MetaSGD`.
- """
- self.impl = impl
- self.param_containers_groups = [] # type: ignore
- self.state_groups = [] # type: ignore
-
- self.add_param_group(net)
-
- def step(self, loss: torch.Tensor):
- """Compute the gradients of the loss to the network parameters and update network parameters.
-
- Graph of the derivative will be constructed, allowing to compute higher order derivative products.
- We use the differentiable optimizer (pass argument inplace=False) to scale the gradients and update
- the network parameters without modifying tensors in-place.
-
- Args:
- loss (torch.Tensor): the loss that is used to compute the gradients to the network parameters.
- """
- # step parameter only
- for idx, (state, param_containers) in enumerate(
- zip(self.state_groups, self.param_containers_groups)
- ):
- flatten_params, containers_tree = jax.tree_util.tree_flatten(
- param_containers
- )
- flatten_params = tuple(flatten_params)
- grad = torch.autograd.grad(
- loss, flatten_params, create_graph=True, allow_unused=True
- )
- updates, state = self.impl.update(grad, state, False)
- self.state_groups[idx] = state
- new_params = apply_updates(flatten_params, updates, inplace=False)
- unflatten_new_params = containers_tree.unflatten(new_params)
- for (container,
- unflatten_param) in zip(param_containers, unflatten_new_params):
- container.update(unflatten_param)
-
- def add_param_group(self, net):
- from TorchOpt.utils import _extract_container
- net_container = _extract_container(net, with_buffer=False)
- flatten_param, _ = jax.tree_util.tree_flatten(net_container)
- flatten_param = tuple(flatten_param)
- optim_state = self.impl.init(flatten_param)
- self.state_groups.append(optim_state)
- self.param_containers_groups.append(net_container)
-
- def state_dict(self):
- """Extract the references of the optimizer states.
-
- Note that the states are references, so any in-place operations will
- change the states inside `MetaOptimizer` at the same time.
- """
- out_groups = tuple(group for group in self.state_groups)
- return out_groups
-
- def load_state_dict(self, state_dict):
- self.state_groups = list(group for group in state_dict)
-
-
-class MetaSGD(MetaOptimizer):
- """A canonical Stochastic Gradient Descent optimiser."""
-
- def __init__(
- self,
- net: nn.Module,
- lr: ScalarOrSchedule,
- momentum: Union[float, None] = None,
- nesterov: bool = False,
- moment_requires_grad: bool = True
- ):
- """The `init` function.
- Args:
- net (nn.Module): a network whose parameters should be optimized.
- args: other arguments see `alias.sgd`, here we set `moment_requires_grad=True`
- to make tensors like momentum be differentiable.
- """
- super().__init__(
- net,
- sgd(
- lr=lr,
- momentum=momentum,
- nesterov=nesterov,
- moment_requires_grad=moment_requires_grad
- )
- )
-
-
-class MetaAdam(MetaOptimizer):
- """The classic Adam optimiser."""
-
- def __init__(
- self,
- net,
- lr: ScalarOrSchedule,
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- moment_requires_grad: bool = True,
- use_accelerated_op: bool = False
- ):
- """The `init` function.
- Args:
- net (nn.Module): a network whose parameters should be optimized.
- args: other arguments see `alias.adam`, here we set `moment_requires_grad=True`
- to make tensors like momentum be differentiable.
- """
- super().__init__(
- net,
- adam(
- lr=lr,
- b1=b1,
- b2=b2,
- eps=eps,
- eps_root=eps_root,
- moment_requires_grad=moment_requires_grad,
- use_accelerated_op=use_accelerated_op
- )
- )
-
-
-class MetaRMSProp(MetaOptimizer):
- """The classic RMSProp optimiser."""
-
- def __init__(
- self,
- net,
- lr: ScalarOrSchedule,
- decay: float = 0.9,
- eps: float = 1e-8,
- initial_scale: float = 0.,
- centered: bool = False,
- momentum: Union[float, None] = None,
- nesterov: bool = False
- ):
- """The `init` function.
- Args:
- net (nn.Module): a network whose parameters should be optimized.
- args: other arguments see `alias.adam`, here we set `moment_requires_grad=True`
- to make tensors like momentum be differentiable.
- """
- super().__init__(
- net,
- rmsprop(
- lr=lr,
- decay=decay,
- eps=eps,
- initial_scale=initial_scale,
- centered=centered,
- momentum=momentum,
- nesterov=nesterov
- )
- )
diff --git a/TorchOpt/_src/Optimizer.py b/TorchOpt/_src/Optimizer.py
deleted file mode 100644
index 8544d3da..00000000
--- a/TorchOpt/_src/Optimizer.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from typing import Iterable, Union
-
-import jax
-import torch
-
-from TorchOpt._src import base
-from TorchOpt._src.alias import adam, rmsprop, sgd
-from TorchOpt._src.pytypes import ScalarOrSchedule
-from TorchOpt._src.update import apply_updates
-
-
-class Optimizer(object):
- """A high-level base class that has the similar with `torch.optim.Optimier`."""
-
- def __init__(self, params: Iterable, impl: base.GradientTransformation):
- """The `init` function.
-
- Args:
- params (iterable): an iterable of `torch.Tensor`s. Specifies what Tensors should be optimized.
- impl (base.GradientTransformation): a low level optimizer function, it could be
- a optimizer function provided by `alias.py` or a customerized `chain` provided by
- `combine.py`. Note that use `MetaOptimizer(sgd())` or `MetaOptimizer(chain(sgd()))
- is equavalent to `SGD`.
- """
- if not isinstance(params, list):
- params = list(params)
- self.impl = impl
- self.param_groups = [] # type: ignore
- self.param_tree_groups = [] # type: ignore
- self.state_groups = [] # type: ignore
- self.add_param_group(params)
-
- def zero_grad(self, set_to_none: bool = False):
- """Sets the gradients of all optimized `torch.Tensor`s to zero.
-
- The behivour is similar to `torch.optim.Optimizer.zero_grad`.
-
- Args:
- set_to_none (bool): instead of setting to zero, set the grads to None.
-
- """
- for group in self.param_groups:
- if set_to_none:
-
- def f(p):
- p.grad = None
- return None
- else:
-
- def f(p):
- if p.grad is None:
- return None
- if p.grad.grad_fn is not None:
- p.grad.detach_()
- else:
- p.grad.requires_grad_(False)
- p.grad.zero_()
- return None
-
- jax.tree_map(f, group)
-
- def state_dict(self):
- """Returns the state of the optimizer."""
- return self.state_groups
-
- def load_state_dict(self, state_dict):
- """Loads the optimizer state.
-
- Args:
- state_dict (dict): optimizer state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- self.state_groups = state_dict
-
- def step(self, closure=None):
- """Performs a single optimization step (parameter update).
-
- The behivour is similar to `torch.optim.Optimizer.step`.
-
- Args:
- closure (callable, optional): A closure that reevaluates the model and returns the loss.
-
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- for param, state in zip(self.param_groups, self.state_groups):
-
- def f(p):
- return p.grad
-
- grad = jax.tree_map(f, param)
- updates, _ = self.impl.update(grad, state)
- apply_updates(param, updates)
-
- return loss
-
- def add_param_group(self, params):
- params, tree = jax.tree_flatten(params)
- params = tuple(params)
- self.param_groups.append(params)
- self.param_tree_groups.append(tree)
- self.state_groups.append(self.impl.init(params))
-
-
-class SGD(Optimizer):
- """The classic Adam optimiser."""
-
- def __init__(
- self,
- params,
- lr: ScalarOrSchedule,
- momentum: Union[float, None] = None,
- nesterov: bool = False
- ):
- """The `init` function.
-
- Args:
- params (iterable): an iterable of `torch.Tensor`s. Specifies what Tensors should be optimized.
- args: other arguments see `alias.adam`.
-
- """
- super().__init__(
- params,
- sgd(
- lr=lr,
- momentum=momentum,
- nesterov=nesterov,
- moment_requires_grad=False
- )
- )
-
-
-class Adam(Optimizer):
- """A canonical Stochastic Gradient Descent optimiser."""
-
- def __init__(
- self,
- params,
- lr: ScalarOrSchedule,
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- use_accelerated_op: bool = False
- ):
- """The `init` function.
-
- Args:
- params (iterable): an iterable of `torch.Tensor`s. Specifies what Tensors should be optimized.
- args: other arguments see `alias.sgd`.
- """
- super().__init__(
- params,
- adam(
- lr=lr,
- b1=b1,
- b2=b2,
- eps=eps,
- eps_root=eps_root,
- moment_requires_grad=False,
- use_accelerated_op=use_accelerated_op
- )
- )
-
-
-class RMSProp(Optimizer):
- """An RMSProp optimiser."""
-
- def __init__(
- self,
- params,
- lr: ScalarOrSchedule,
- decay: float = 0.9,
- eps: float = 1e-8,
- initial_scale: float = 0.,
- centered: bool = False,
- momentum: Union[float, None] = None,
- nesterov: bool = False
- ):
- """The `init` function.
-
- Args:
- params (iterable): an iterable of `torch.Tensor`s. Specifies what Tensors should be optimized.
- args: other arguments see `alias.sgd`.
- """
- super().__init__(
- params,
- rmsprop(
- lr=lr,
- decay=decay,
- eps=eps,
- initial_scale=initial_scale,
- centered=centered,
- momentum=momentum,
- nesterov=nesterov
- )
- )
diff --git a/TorchOpt/_src/accelerated_op/adam_op/AdamOp.py b/TorchOpt/_src/accelerated_op/adam_op/AdamOp.py
deleted file mode 100644
index e726a61a..00000000
--- a/TorchOpt/_src/accelerated_op/adam_op/AdamOp.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from typing import Any
-
-import torch
-
-from TorchOpt._lib import adam_op
-
-
-class AdamOp(object):
-
- class MuOp(torch.autograd.Function):
-
- @staticmethod
- def jvp(ctx: Any, *grad_inputs: Any) -> Any:
- pass
-
- @staticmethod
- def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
- updates, mu, b1 = args
- new_mu = adam_op.forwardMu(updates, mu, b1)
- ctx.save_for_backward(updates, mu)
- ctx.b1 = b1
- return new_mu
-
- @staticmethod
- def backward(ctx: Any, *args: Any) -> Any:
- dmu = args[0]
- updates, mu = ctx.saved_tensors
- b1 = ctx.b1
- result = adam_op.backwardMu(dmu, updates, mu, b1)
- return result[0], result[1], None
-
- class NuOp(torch.autograd.Function):
-
- @staticmethod
- def jvp(ctx: Any, *grad_inputs: Any) -> Any:
- pass
-
- @staticmethod
- def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
- updates, nu, b2 = args
- new_nu = adam_op.forwardNu(updates, nu, b2)
- ctx.save_for_backward(updates, nu)
- ctx.b2 = b2
- return new_nu
-
- @staticmethod
- def backward(ctx: Any, *args: Any) -> Any:
- dnu = args[0]
- updates, nu = ctx.saved_tensors
- b2 = ctx.b2
- result = adam_op.backwardNu(dnu, updates, nu, b2)
- return result[0], result[1], None
-
- class UpdatesOp(torch.autograd.Function):
-
- @staticmethod
- def jvp(ctx: Any, *grad_inputs: Any) -> Any:
- pass
-
- @staticmethod
- def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
- new_mu, new_nu, (b1, b2, eps, eps_root, count) = args
- new_updates = adam_op.forwardUpdates(
- new_mu, new_nu, b1, b2, eps, eps_root, count
- )
- ctx.save_for_backward(new_updates, new_mu, new_nu)
- ctx.others = (b1, b2, eps, eps_root, count)
- return new_updates
-
- @staticmethod
- def backward(ctx: Any, *args: Any) -> Any:
- dupdates = args[0]
- updates, new_mu, new_nu = ctx.saved_tensors
- b1, b2, eps, eps_root, count = ctx.others
- result = adam_op.backwardUpdates(
- dupdates, updates, new_mu, new_nu, b1, b2, count
- )
- return result[0], result[1], None
-
- def __init__(self, b1=0.9, b2=0.999, eps=1e-8, eps_root=0., inplace=True):
- self.b1 = b1
- self.b2 = b2
- self.eps = eps
- self.eps_root = eps_root
- self.inplace = inplace
-
- def __call__(self, mu, nu, updates, count):
- if updates is None:
- return mu, nu, None
- if updates.is_cuda:
- current_device = torch.cuda.current_device()
- torch.cuda.set_device(updates.device)
- if self.inplace:
- new_updates, new_mu, new_nu = adam_op.forward_(
- updates, mu, nu, self.b1, self.b2, self.eps, self.eps_root, count
- )
- else:
- new_mu = self.MuOp.apply(updates, mu, self.b1)
- new_nu = self.NuOp.apply(updates, nu, self.b2)
- new_updates = self.UpdatesOp.apply(
- new_mu, new_nu, (self.b1, self.b2, self.eps, self.eps_root, count)
- )
- if updates.is_cuda:
- torch.cuda.set_device(current_device)
- return new_mu, new_nu, new_updates
diff --git a/TorchOpt/_src/alias.py b/TorchOpt/_src/alias.py
deleted file mode 100644
index 3f676efe..00000000
--- a/TorchOpt/_src/alias.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# This file is modified from:
-# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py
-# ==============================================================================
-# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from typing import Optional
-
-import jax
-
-from TorchOpt._src import base, combine, transform
-from TorchOpt._src.pytypes import ScalarOrSchedule
-
-
-def _scale_by_lr(lr: ScalarOrSchedule, flip_sign=True):
- m = -1 if flip_sign else 1
- if callable(lr):
-
- def schedule_wrapper(count):
-
- def f(scaled_lr):
- return m * scaled_lr
-
- return jax.tree_map(f, lr(count)) # type: ignore
-
- return transform.scale_by_schedule(schedule_wrapper)
- return transform.scale(m * lr)
-
-
-def adam(
- lr: ScalarOrSchedule,
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- moment_requires_grad: bool = False,
- use_accelerated_op: bool = False
-) -> base.GradientTransformation:
- """The classic Adam optimiser.
-
- Adam is an SGD variant with learning rate adaptation. The `lr`
- used for each weight is computed from estimates of first- and second-order
- moments of the gradients (using suitable exponential moving averages).
-
- References:
- Kingma et al, 2014: https://arxiv.org/abs/1412.6980
-
- Args:
- lr: this is a fixed global scaling factor.
- b1: the exponential decay rate to track the first moment of past gradients.
- b2: the exponential decay rate to track the second moment of past gradients.
- eps: a small constant applied to denominator outside of the square root
- (as in the Adam paper) to avoid dividing by zero when rescaling.
- eps_root: (default `0`), a small constant applied to denominator inside the
- square root (as in RMSProp), to avoid dividing by zero when rescaling.
- This is needed for example when computing (meta-)gradients through Adam.
- moment_requires_grad: (default `False`), if True the momentums will be created with flag
- `requires_grad=True`, this flag is often used in Meta Learning algorithms.
- use_accelerated_op: (default `False`), if True use our implemented fused operator.
-
- Returns:
- the corresponding `GradientTransformation`.
- """
- adam_inst = transform.scale_by_accelerated_adam if use_accelerated_op else transform.scale_by_adam
- return combine.chain(
- adam_inst(
- b1=b1,
- b2=b2,
- eps=eps,
- eps_root=eps_root,
- moment_requires_grad=moment_requires_grad
- ),
- _scale_by_lr(lr),
- )
-
-
-def sgd(
- lr: ScalarOrSchedule,
- momentum: Optional[float] = None,
- nesterov: bool = False,
- moment_requires_grad: bool = False,
-) -> base.GradientTransformation:
- """A canonical Stochastic Gradient Descent optimiser.
-
- This implements stochastic gradient descent. It also includes support for
- momentum, and nesterov acceleration, as these are standard practice when
- using stochastic gradient descent to train deep neural networks.
-
- References:
- Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf
-
- Args:
- lr: this is a fixed global scaling factor.
- momentum: (default `None`), the `decay` rate used by the momentum term,
- when it is set to `None`, then momentum is not used at all.
- nesterov (default `False`): whether nesterov momentum is used.
- moment_requires_grad: (default `False`), if True the momentums will be created with flag
- `requires_grad=True`, this flag is often used in Meta Learning algorithms.
-
- Returns:
- A `GradientTransformation`.
- """
- return combine.chain(
- (
- transform.trace(
- decay=momentum,
- nesterov=nesterov,
- moment_requires_grad=moment_requires_grad
- ) if momentum is not None else base.identity()
- ), _scale_by_lr(lr)
- )
-
-
-def rmsprop(
- lr: ScalarOrSchedule,
- decay: float = 0.9,
- eps: float = 1e-8,
- initial_scale: float = 0.,
- centered: bool = False,
- momentum: Optional[float] = None,
- nesterov: bool = False
-) -> base.GradientTransformation:
- # pylint: disable=line-too-long
- """A flexible RMSProp optimiser.
- RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
- used for each weight is scaled by a suitable estimate of the magnitude of the
- gradients on previous steps. Several variants of RMSProp can be found
- in the literature. This alias provides an easy to configure RMSProp
- optimiser that can be used to switch between several of these variants.
- References:
- Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
- Graves, 2013: https://arxiv.org/abs/1308.0850
- Args:
- learning_rate: this is a fixed global scaling factor.
- decay: the decay used to track the magnitude of previous gradients.
- eps: a small numerical constant to avoid dividing by zero when rescaling.
- initial_scale: (default `0.`), initialisation of accumulators tracking the
- magnitude of previous updates. PyTorch uses `0`, TF1 uses `1`. When
- reproducing results from a paper, verify the value used by the authors.
- centered: (default `False`), whether the second moment or the variance of
- the past gradients is used to rescale the latest gradients.
- momentum: (default `None`), the `decay` rate used by the momentum term,
- when it is set to `None`, then momentum is not used at all.
- nesterov (default `False`): whether nesterov momentum is used.
- Returns:
- the corresponding `GradientTransformation`.
- """
- # pylint: enable=line-too-long
- if centered:
- return combine.chain(
- transform.scale_by_stddev(
- decay=decay, eps=eps, initial_scale=initial_scale
- ), _scale_by_lr(lr), (
- transform.trace(decay=momentum, nesterov=nesterov)
- if momentum is not None else base.identity()
- )
- )
- return combine.chain(
- transform.scale_by_rms(decay=decay, eps=eps, initial_scale=initial_scale),
- _scale_by_lr(lr), (
- transform.trace(decay=momentum, nesterov=nesterov)
- if momentum is not None else base.identity()
- )
- )
diff --git a/TorchOpt/_src/base.py b/TorchOpt/_src/base.py
deleted file mode 100644
index 8b1559e6..00000000
--- a/TorchOpt/_src/base.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# This file is modified from:
-# https://github.com/deepmind/optax/blob/master/optax/_src/base.py
-# ==============================================================================
-# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from typing import Callable, NamedTuple, Tuple
-
-import typing_extensions
-
-from TorchOpt._src import pytypes
-
-OptState = pytypes.TensorTree # States are arbitrary nests of `torch.Tensor`.
-# Parameters are arbitrary nests of `torch.Tensor`.
-Params = pytypes.TensorTree
-Updates = Params # Gradient updates are of the same type as parameters.
-
-Schedule = Callable[[pytypes.Numeric], pytypes.Numeric]
-
-
-class EmptyState(NamedTuple):
- """An empty state for the simplest stateless transformations."""
-
-
-class TransformInitFn(typing_extensions.Protocol):
- """A callable type for the `init` step of a `GradientTransformation`.
-
- The `init` step takes a tree of `params` and uses these to construct an
- arbitrary structured initial `state` for the gradient transformation. This
- may hold statistics of the past updates or any other non static information.
- """
-
- def __call__(self, params: Params) -> OptState:
- """The `init` function.
-
- Args:
- params: The initial value of the parameters.
-
- Returns:
- The initial state of the gradient transformation.
- """
- ...
-
-
-class TransformUpdateFn(typing_extensions.Protocol):
- """A callable type for the `update` step of a `GradientTransformation`.
-
- The `update` step takes a tree of candidate parameter `updates` (e.g. their
- gradient with respect to some loss), an arbitrary structured `state`, and the
- current `params` of the model being optimised. The `params` argument is
- optional, it must however be provided when using transformations that require
- access to the current values of the parameters.
- """
-
- def __call__(self,
- updates: Updates,
- state: OptState,
- inplace: bool = True) -> Tuple[Updates, OptState]:
- """The `update` function.
-
- Args:
- updates: A tree of candidate updates.
- state: The state of the gradient transformation.
- inplace: (Optionally) if true, modify updates and state using inplace operations.
-
- Returns:
- The transformed updates, and the updated state.
- """
- ...
-
-
-class GradientTransformation(NamedTuple):
- """A pair of pure functions implementing a gradient transformation.
-
- TorchOpt optimizers are all implemented as _gradient transformations_ like
- Optax. A gradient transformation is defined to be a pair of pure functions,
- which are combined together in a `NamedTuple` so that they can be referred
- to by name.
-
- Since gradient transformations do not contain any internal state, all stateful
- optimizer properties (such as the current step count when using optimizer
- scheduels, or momemtum values) are passed through gradient transformations by
- using the optimizer _state_ pytree. Each time a gradient transformation is
- applied, the state is computed and returned, ready to be passed to the next
- call to the gradient transformation.
-
- Attributes:
- init: A pure function which, when called with an example instance of the
- parameters whose gradients will be transformed, returns a pytree
- containing the initial value for the optimizer state.
- update: A pure function which takes as input a pytree of updates (with the
- same tree structure as the original params pytree passed to init), the
- previous optimizer state (which may have been initialized using the init
- function), and optionally the inplace flag. The update function then
- returns the computed gradient updates, and a updates optimizer state.
- If the inplace flag is true, the output results are the same instance as
- the input.
- """
- init: TransformInitFn
- update: TransformUpdateFn
-
-
-def identity() -> GradientTransformation:
- """Stateless identity transformation that leaves input gradients untouched.
-
- This function passes through the *gradient updates* unchanged.
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(_):
- return EmptyState()
-
- def update_fn(updates, state, inplace=False):
- return updates, state
-
- return GradientTransformation(init_fn, update_fn)
diff --git a/TorchOpt/_src/clip.py b/TorchOpt/_src/clip.py
deleted file mode 100644
index b0e24aed..00000000
--- a/TorchOpt/_src/clip.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# This file is modified from:
-# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py
-# ==============================================================================
-
-import jax
-import torch
-from torch._six import inf
-
-from TorchOpt._src import base
-
-ClipState = base.EmptyState
-
-
-def clip_grad_norm(
- max_norm: float,
- norm_type: float = 2.,
- error_if_nonfinite: bool = False
-) -> base.GradientTransformation:
- """Clips gradient norm of an iterable of parameters.
- Args:
- max_delta: The maximum absolute value for each element in the update.
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(params):
- del params
- return ClipState()
-
- def update_fn(updates, state, inplace=True):
- available_updates = []
- for g in updates:
- if g is not None:
- available_updates.append(g)
- if len(available_updates) == 0:
- return torch.tensor(0.)
- device = available_updates[0].device
- with torch.no_grad():
- if norm_type == inf:
- norms = [p.abs().max().to(device) for p in available_updates]
- total_norm = norms[0] if len(norms) == 1 else torch.max(
- torch.stack(norms)
- )
- else:
- total_norm = torch.norm(
- torch.stack(
- [torch.norm(p, norm_type).to(device) for p in available_updates]
- ), norm_type
- )
- if error_if_nonfinite and torch.logical_or(
- total_norm.isnan(), total_norm.isinf()
- ):
- raise RuntimeError(
- f'The total norm of order {norm_type} for gradients from '
- '`parameters` is non-finite, so it cannot be clipped. To disable '
- 'this error and scale the gradients by the non-finite norm anyway, '
- 'set `error_if_nonfinite=False`'
- )
- clip_coef = max_norm / (float(total_norm) + 1e-6)
- # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
- # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
- # when the gradients do not reside in CPU memory.
- clip_coef_clamped = min(clip_coef, 1.)
- if inplace:
-
- def f(g):
- return g.mul_(clip_coef_clamped) if g is not None else None
- else:
-
- def f(g):
- return g.mul(clip_coef_clamped) if g is not None else None
-
- new_updates = jax.tree_map(f, updates)
- return new_updates, state
-
- return base.GradientTransformation(init_fn, update_fn)
diff --git a/TorchOpt/_src/schedule.py b/TorchOpt/_src/schedule.py
deleted file mode 100644
index ad24cf82..00000000
--- a/TorchOpt/_src/schedule.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# This file is modified from:
-# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py
-# ==============================================================================
-# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import jax
-import numpy as np
-from absl import logging
-
-from TorchOpt._src import base, pytypes
-
-
-def polynomial_schedule(
- init_value: pytypes.Scalar,
- end_value: pytypes.Scalar,
- power: pytypes.Scalar,
- transition_steps: int,
- transition_begin: int = 0
-) -> base.Schedule:
- """Constructs a schedule with polynomial transition from init to end value.
- Args:
- init_value: initial value for the scalar to be annealed.
- end_value: end value of the scalar to be annealed.
- power: the power of the polynomial used to transition from init to end.
- transition_steps: number of steps over which annealing takes place,
- the scalar starts changing at `transition_begin` steps and completes
- the transition by `transition_begin + transition_steps` steps.
- If `transition_steps <= 0`, then the entire annealing process is disabled
- and the value is held fixed at `init_value`.
- transition_begin: must be positive. After how many steps to start annealing
- (before this many steps the scalar value is held fixed at `init_value`).
- Returns:
- schedule: A function that maps step counts to values.
- """
- if transition_steps <= 0:
- logging.info(
- 'A polynomial schedule was set with a non-positive `transition_steps` '
- 'value; this results in a constant schedule with value `init_value`.'
- )
- return lambda count: init_value
-
- if transition_begin < 0:
- logging.info(
- 'An exponential schedule was set with a negative `transition_begin` '
- 'value; this will result in `transition_begin` falling back to `0`.'
- )
- transition_begin = 0
-
- def schedule(count):
-
- def impl(count):
- count = np.clip(count - transition_begin, 0, transition_steps)
- frac = 1 - count / transition_steps
- return (init_value - end_value) * (frac**power) + end_value
-
- return jax.tree_map(impl, count)
-
- return schedule
-
-
-# Alias polynomial schedule to linear schedule for convenience.
-def linear_schedule(
- init_value: pytypes.Scalar,
- end_value: pytypes.Scalar,
- transition_steps: int,
- transition_begin: int = 0
-) -> base.Schedule:
- return polynomial_schedule(
- init_value=init_value,
- end_value=end_value,
- power=1,
- transition_steps=transition_steps,
- transition_begin=transition_begin
- )
diff --git a/TorchOpt/_src/transform.py b/TorchOpt/_src/transform.py
deleted file mode 100644
index 7cdc9c86..00000000
--- a/TorchOpt/_src/transform.py
+++ /dev/null
@@ -1,469 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# This file is modified from:
-# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py
-# ==============================================================================
-# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from typing import List, NamedTuple, Tuple, Union
-
-import jax
-import torch
-
-from TorchOpt._src import base
-from TorchOpt._src.pytypes import ScalarOrSchedule, Schedule
-
-ScaleState = base.EmptyState
-
-
-def inc_count(updates, count: Tuple[int]) -> Tuple[int]:
-
- def f(c, g):
- return c + 1 if g is not None else c
-
- return jax.tree_map(f, count, updates)
-
-
-def scale(step_size: float) -> base.GradientTransformation:
- """Scale updates by some fixed scalar `step_size`.
-
- Args:
- step_size: a scalar corresponding to a fixed scaling factor for updates.
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(params):
- del params
- return ScaleState()
-
- def update_fn(updates, state, inplace=True):
- if inplace:
-
- def f(g):
- return g.mul_(step_size) if g is not None else None
- else:
-
- def f(g):
- return g.mul(step_size) if g is not None else None
-
- updates = jax.tree_map(f, updates)
- return updates, state
-
- return base.GradientTransformation(init_fn, update_fn)
-
-
-class ScaleByScheduleState(NamedTuple):
- """Maintains count for scale scheduling."""
- count: Tuple[int, ...] # type: ignore
-
-
-def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation:
- """Scale updates using a custom schedule for the `step_size`.
-
- Args:
- step_size_fn: a function that takes an update count as input and proposes
- the step_size to multiply the updates by.
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(params):
- return ScaleByScheduleState(count=tuple(0 for _ in range(len(params))))
-
- def update_fn(updates, state, inplace=True):
- step_size = step_size_fn(state.count)
- if inplace:
- updates = jax.tree_map(
- lambda g, step_size: g.mul_(step_size), updates, step_size
- )
- else:
- updates = jax.tree_map(
- lambda g, step_size: g.mul(step_size), updates, step_size
- )
- return updates, ScaleByScheduleState(count=inc_count(updates, state.count))
-
- return base.GradientTransformation(init_fn, update_fn)
-
-
-def _update_moment(updates, moments, decay, order, inplace=True):
- """Compute the exponential moving average of the `order`-th moment."""
- if inplace:
-
- def f(g, t):
- return t.mul_(decay).add_(
- g**order, alpha=1 - decay
- ) if g is not None else t
- else:
-
- def f(g, t):
- return t.mul(decay).add(
- g**order, alpha=1 - decay
- ) if g is not None else t
-
- return jax.tree_map(f, updates, moments)
-
-
-def _update_moment_per_elem_norm(updates, moments, decay, order, inplace=True):
- """Compute the EMA of the `order`-th moment of the element-wise norm."""
-
- if inplace:
-
- def f(g, t):
- return t.mul_(decay).add_(
- g**order, alpha=1 - decay
- ) if g is not None else t
- else:
-
- def f(g, t):
- return t.mul(decay).add(
- g**order, alpha=1 - decay
- ) if g is not None else t
-
- return jax.tree_map(f, updates, moments)
-
-
-class ScaleByAdamState(NamedTuple):
- """State for the Adam algorithm."""
- count: Tuple[int, ...] # type: ignore
- mu: base.Updates
- nu: base.Updates
-
-
-def _bias_correction(moment, decay, count, inplace=True):
- """Perform bias correction. This becomes a no-op as count goes to infinity."""
- if inplace:
-
- def f(t, c):
- return t.div_(1 - decay**c)
- else:
-
- def f(t, c):
- return t.div(1 - decay**c)
-
- return jax.tree_map(f, moment, count)
-
-
-def scale_by_adam(
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- moment_requires_grad: bool = False,
-) -> base.GradientTransformation:
- """Rescale updates according to the Adam algorithm.
-
- References:
- [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
-
- Args:
- b1: decay rate for the exponentially weighted average of grads.
- b2: decay rate for the exponentially weighted average of squared grads.
- eps: term added to the denominator to improve numerical stability.
- eps_root: term added to the denominator inside the square-root to improve
- numerical stability when backpropagating gradients through the rescaling.
- moment_requires_grad: if true, states will be created with flag `requires_grad = True`.
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(params):
- mu = jax.tree_map( # First moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
- params)
- nu = jax.tree_map( # Second moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
- params)
- return ScaleByAdamState(
- count=tuple(0 for _ in range(len(mu))), mu=tuple(mu), nu=tuple(nu)
- )
-
- def update_fn(updates, state, inplace=True):
- mu = _update_moment(updates, state.mu, b1, 1, inplace)
- nu = _update_moment_per_elem_norm(updates, state.nu, b2, 2, inplace)
- count_inc = inc_count(updates, state.count)
- mu_hat = _bias_correction(mu, b1, count_inc, False)
- nu_hat = _bias_correction(nu, b2, count_inc, False)
- if inplace:
-
- def f(g, m, v):
- return m.div_(
- torch.sqrt_(v.add_(eps_root)).add_(eps)
- ) if g is not None else None
- else:
-
- def f(g, m, v):
- return m.div(
- torch.sqrt(v.add(eps_root)).add(eps)
- ) if g is not None else None
-
- updates = jax.tree_map(f, updates, mu_hat, nu_hat)
- return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
-
- return base.GradientTransformation(init_fn, update_fn)
-
-
-def scale_by_accelerated_adam(
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- moment_requires_grad: bool = False,
-) -> base.GradientTransformation:
- """Rescale updates according to the Adam algorithm.
-
- This function is acceleracted by using some fused accelerated operators.
-
- References:
- [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
-
- Args:
- b1: decay rate for the exponentially weighted average of grads.
- b2: decay rate for the exponentially weighted average of squared grads.
- eps: term added to the denominator to improve numerical stability.
- eps_root: term added to the denominator inside the square-root to improve
- numerical stability when backpropagating gradients through the rescaling.
- moment_requires_grad: if true, states will be created with flag `requires_grad = True`.
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
- from .accelerated_op import AdamOp
-
- def init_fn(params):
- mu = jax.tree_map( # First moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
- params)
- nu = jax.tree_map( # Second moment
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
- params)
- return ScaleByAdamState(
- count=tuple(0 for _ in range(len(params))), mu=mu, nu=nu
- )
-
- def update_fn(updates, state, inplace=True):
- count_inc = inc_count(updates, state.count)
- op = AdamOp(b1, b2, eps, eps_root, inplace)
- out = jax.tree_map(op, state.mu, state.nu, updates, count_inc)
- new_mus, new_nus, new_updates = [], [], []
- for new_mu, new_nu, new_update in out:
- new_mus.append(new_mu)
- new_nus.append(new_nu)
- new_updates.append(new_update)
- return tuple(new_updates), ScaleByAdamState(
- count=count_inc, mu=tuple(new_mus), nu=tuple(new_nus)
- )
-
- return base.GradientTransformation(init_fn, update_fn)
-
-
-class TraceState(NamedTuple):
- """Holds an aggregation of past updates."""
- trace: base.Params
-
-
-def trace(
- decay: float,
- nesterov: bool = False,
- moment_requires_grad: bool = False,
-) -> base.GradientTransformation:
- """Compute a trace of past updates.
-
- Note: `trace` and `ema` have very similar but distinct updates;
- `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`.
- Both are frequently found in the optimisation literature.
-
- Args:
- decay: the decay rate for the trace of past updates.
- nesterov: whether to use Nesterov momentum.
- moment_requires_grad: if true, states will be created with flag `requires_grad = True`.
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(params):
- if decay == 0.:
- return TraceState(trace=())
- else:
- return TraceState(
- trace=jax.tree_map(
- lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
- params
- )
- )
-
- def update_fn(updates, state, inplace=True):
- if nesterov:
- if inplace:
-
- def f1(g, t):
- return t.copy_(g.add(t, alpha=decay))
-
- def f2(g, t):
- return g.add_(t, alpha=decay)
-
- new_trace = jax.tree_map(f1, updates, state.trace)
- updates = jax.tree_map(f2, updates, new_trace)
- else:
-
- def f(g, t):
- return g.add(t, alpha=decay)
-
- new_trace = jax.tree_map(f, updates, state.trace)
- updates = jax.tree_map(f, updates, new_trace)
- else:
- if inplace:
-
- def f(g, t):
- return g.add_(t, alpha=decay)
-
- updates = jax.tree_map(f, updates, state.trace)
- state.trace.copy_(updates)
- new_trace = state.trace
- else:
-
- def f(g, t):
- return g.add(t, alpha=decay)
-
- updates = jax.tree_map(f, updates, state.trace)
- new_trace = updates
-
- return updates, TraceState(trace=new_trace)
-
- return base.GradientTransformation(init_fn, update_fn)
-
-
-class ScaleByRmsState(NamedTuple):
- """State for exponential root mean-squared (RMS)-normalized updates."""
- nu: base.Updates
-
-
-def scale_by_rms(
- decay: float = 0.9,
- eps: float = 1e-8,
- initial_scale: float = 0.
-) -> base.GradientTransformation:
- """Rescale updates by the root of the exp. moving avg of the square.
-
- References:
- [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
-
- Args:
- decay: decay rate for the exponentially weighted average of squared grads.
- eps: term added to the denominator to improve numerical stability.
- initial_scale: initial value for second moment
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(params):
- nu = jax.tree_map(
- lambda n: torch.full_like(n, initial_scale), params
- ) # second moment
- return ScaleByRmsState(nu=nu)
-
- def update_fn(updates, state, inplace=True):
- nu = _update_moment_per_elem_norm(updates, state.nu, decay, 2, inplace)
- if inplace:
-
- def f(g, n):
- return g.mul_(torch.rsqrt(n.add(eps)))
- else:
-
- def f(g, n):
- return g.mul(torch.rsqrt(n.add(eps)))
-
- # """The followings are pytorch style"""
- # if inplace:
- # def f(g, n): return g.div_(torch.sqrt_(n).add_(eps))
- # else:
- # def f(g, n): return g.div(torch.sqrt(n).add(eps))
- updates = jax.tree_map(f, updates, nu)
- return updates, ScaleByRmsState(nu=nu)
-
- return base.GradientTransformation(init_fn, update_fn)
-
-
-class ScaleByRStdDevState(NamedTuple):
- """State for centered exponential moving average of squares of updates."""
- mu: base.Updates
- nu: base.Updates
-
-
-def scale_by_stddev(
- decay: float = 0.9,
- eps: float = 1e-8,
- initial_scale: float = 0.
-) -> base.GradientTransformation:
- """Rescale updates by the root of the centered exp. moving average of squares.
-
- References:
- [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
-
- Args:
- decay: decay rate for the exponentially weighted average of squared grads.
- eps: term added to the denominator to improve numerical stability.
- initial_scale: initial value for second moment
-
- Returns:
- An (init_fn, update_fn) tuple.
- """
-
- def init_fn(params):
- mu = jax.tree_map(torch.zeros_like, params) # First moment
- nu = jax.tree_map(
- lambda n: torch.full_like(n, initial_scale), params
- ) # second moment
- return ScaleByRStdDevState(mu=mu, nu=nu)
-
- def update_fn(updates, state, inplace=True):
- mu = _update_moment(updates, state.mu, decay, 1, inplace)
- nu = _update_moment_per_elem_norm(updates, state.nu, decay, 2, inplace)
- if inplace:
-
- def f(g, m, n):
- return g.mul_(torch.rsqrt(n.sub(m**2).add(eps)))
- else:
-
- def f(g, m, n):
- return g.mul(torch.rsqrt(n.sub(m**2).add(eps)))
-
- # """The followings are pytorch style"""
- # if inplace:
- # def f(g, m, n): return g.div_(torch.sqrt_(n.sub_(m ** 2)).add(eps))
- # else:
- # def f(g, m, n): return g.div(torch.sqrt(n.sub(m ** 2)).add(eps))
- updates = jax.tree_map(f, updates, mu, nu)
- return updates, ScaleByRStdDevState(mu=mu, nu=nu)
-
- return base.GradientTransformation(init_fn, update_fn)
diff --git a/TorchOpt/_src/utils.py b/TorchOpt/_src/utils.py
deleted file mode 100644
index 23c28ae9..00000000
--- a/TorchOpt/_src/utils.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from typing import Dict, List, NamedTuple, Union
-
-import jax
-import torch
-from torch import nn
-
-from TorchOpt._src.MetaOptimizer import MetaOptimizer
-
-
-class _ModuleState(NamedTuple):
- params: List[Dict]
-
- visual_contents: Union[None, Dict] = None
-
-
-# mypy: ignore-errors
-def stop_gradient(target):
- """Stop the gradient for the input object.
-
- Since a tensor use `grad_fn` to connect itself with the previous computation
- graph, the back-propagated gradient will flow over the tensor and continue
- flow to the tensors that is connected by `grad_fn`. Some algorithms requires
- manually detaching tensors from the computation graph.
-
- Note that the stop_gradient operation is in-place.
-
- Args:
- target: the target that to be detached from the computation graph, it coule
- be a `nn.Module`, `TorchOpt.MetaOptimizer`, state of the
- `TorchOpt.MetaOptimizer`, or just a plain list of tensors.
- inplace: if True, the target will be detached in-place. if False, this function
- will return a detached copy of the target. The in-place operation is fast
- and memory efficient but may raise back-propagation error.
- """
-
- def f(obj):
- if isinstance(obj, torch.Tensor):
- requires_grad = obj.requires_grad
- obj.detach_().requires_grad_(requires_grad)
- return None
-
- if isinstance(target, _ModuleState):
- true_target = target.params
- elif isinstance(target, nn.Module):
- true_target = tuple(target.parameters())
- elif isinstance(target, MetaOptimizer):
- true_target, _ = jax.tree_flatten(target.state_dict())
- else:
- true_target = target
-
- jax.tree_map(f, true_target)
-
-
-def extract_state_dict(
- mod, copy=False, *, with_buffer=True, enable_visual=False, visual_prefix=''
-):
- """Extract target state.
-
- Since a tensor use `grad_fn` to connect itself with the previous computation
- graph, the back-propagated gradient will flow over the tensor and continue
- flow to the tensors that is connected by `grad_fn`. Some algorithms requires
- manually detaching tensors from the computation graph.
-
- Note that the extracted state is a reference, which means any in-place operatior
- will affect the target that the state is extracted from.
-
- Args:
- mod: it coule be a `nn.Module` or `TorchOpt.MetaOptimizer`.
- with_buffer: extract buffer together with parameters, this argument is only
- used if the input target is `nn.Module`.
- enable_visual: add additional annoations, which could be used in computation
- graph visualization. Currently, this flag only has effect on `nn.Module` but
- we will support `TorchOpt.MetaOptimizer` later.
- visual_prefix: prefix for the visualization annoations.
-
- Returns:
- State extracted of the input object.
- """
- if isinstance(mod, nn.Module):
- if enable_visual:
- visual_contents = {}
-
- for k, v in mod.named_parameters():
- if v.grad_fn is not None:
- visual_contents.update({v.grad_fn: (visual_prefix + k, v)})
- else:
- visual_contents.update({v: visual_prefix + k})
- else:
- visual_contents = None
-
- params = []
-
- def get_v(v):
- if copy:
- requires_grad = v.requires_grad
- return v.clone().detach_().requires_grad_(requires_grad)
- else:
- return v
-
- def _update(term):
- if len(term) != 0:
- params.append({k: get_v(v) for k, v in term.items()})
-
- _update(mod._parameters)
- if with_buffer:
- _update(mod._buffers)
- for module in mod.modules():
- if module is mod:
- continue
- _update(module._parameters)
- if with_buffer:
- _update(module._buffers)
- return _ModuleState(params=tuple(params), visual_contents=visual_contents)
- elif isinstance(mod, MetaOptimizer):
- state = mod.state_dict()
- if copy:
- flatten_state, state_tree = jax.tree_flatten(state)
-
- def get_v(v):
- if not isinstance(v, torch.Tensor):
- return v
- requires_grad = v.requires_grad
- return v.clone().detach_().requires_grad_(requires_grad)
-
- flatten_state = jax.tree_map(get_v, flatten_state)
- return state_tree.unflatten(flatten_state)
- else:
- return state
-
- else:
- raise RuntimeError(f"Unexpected class of {mod}")
-
-
-def _extract_container(mod, with_buffer=True):
- if isinstance(mod, nn.Module):
- containers = []
-
- def _update(term):
- if len(term) != 0:
- containers.append(term)
-
- _update(mod._parameters)
- if with_buffer:
- _update(mod._buffers)
- for module in mod.modules():
- if module is mod:
- continue
- _update(module._parameters)
- if with_buffer:
- _update(module._buffers)
- return tuple(containers)
- else:
- raise RuntimeError(f"Unexpected class of {mod}")
-
-
-def recover_state_dict(mod, state):
- """Recover state.
-
- This function is compatiable for the `extract_state`.
-
- Note that the recovering process is not in-place, so the tensors of the object
- will not be modified.
-
- Args:
- mod: targe that need to recover.
- state: the recovering state.
- """
- if isinstance(mod, nn.Module):
- target_container = _extract_container(mod)
- for target, source in zip(target_container, state.params):
- target.update(source)
- elif isinstance(mod, MetaOptimizer):
- mod.load_state_dict(state)
- else:
- raise RuntimeError(f"Unexpected class of {mod}")
diff --git a/TorchOpt/_src/visual.py b/TorchOpt/_src/visual.py
deleted file mode 100644
index aa3e9702..00000000
--- a/TorchOpt/_src/visual.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# This file is modified from:
-# https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py
-# ==============================================================================
-
-import warnings
-from collections import namedtuple
-from distutils.version import LooseVersion
-from typing import Dict, Generator
-
-import torch
-from graphviz import Digraph
-
-Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))
-
-# Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*`
-SAVED_PREFIX = "_saved_"
-
-
-def get_fn_name(fn, show_attrs, max_attr_chars):
- name = str(type(fn).__name__)
- if not show_attrs:
- return name
- attrs = dict()
- for attr in dir(fn):
- if not attr.startswith(SAVED_PREFIX):
- continue
- val = getattr(fn, attr)
- attr = attr[len(SAVED_PREFIX):]
- if torch.is_tensor(val):
- attrs[attr] = "[saved tensor]"
- elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val):
- attrs[attr] = "[saved tensors]"
- else:
- attrs[attr] = str(val)
- if not attrs:
- return name
- max_attr_chars = max(max_attr_chars, 3)
- col1width = max(len(k) for k in attrs.keys())
- col2width = min(max(len(str(v)) for v in attrs.values()), max_attr_chars)
- sep = "-" * max(col1width + col2width + 2, len(name))
- attrstr = '%-' + str(col1width) + 's: %' + str(col2width) + 's'
-
- def truncate(s):
- return s[:col2width - 3] + "..." if len(s) > col2width else s
-
- params = '\n'.join(
- attrstr % (k, truncate(str(v))) for (k, v) in attrs.items()
- )
- return name + '\n' + sep + '\n' + params
-
-
-# mypy: ignore-errors
-def make_dot(
- var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50
-):
- """Produces Graphviz representation of PyTorch autograd graph.
-
- If a node represents a backward function, it is gray. Otherwise, the node
- represents a tensor and is either blue, orange, or green:
- - Blue: reachable leaf tensors that requires grad (tensors whose `.grad`
- fields will be populated during `.backward()`)
- - Orange: saved tensors of custom autograd functions as well as those
- saved by built-in backward nodes
- - Green: tensor passed in as outputs
- - Dark green: if any output is a view, we represent its base tensor with
- a dark green node.
-
- Args:
- var: output tensor
- params: [dict of (name, tensor) or state_dict] to add names to node that requires grad
- show_attrs: whether to display non-tensor attributes of backward nodes
- (Requires PyTorch version >= 1.9)
- show_saved: whether to display saved tensor nodes that are not by custom
- autograd functions. Saved tensor nodes for custom functions, if
- present, are always displayed. (Requires PyTorch version >= 1.9)
- max_attr_chars: if show_attrs is `True`, sets max number of characters
- to display for any given attribute.
- """
- if LooseVersion(torch.__version__) < LooseVersion("1.9") and \
- (show_attrs or show_saved):
- warnings.warn(
- "make_dot: showing grad_fn attributes and saved variables"
- " requires PyTorch version >= 1.9. (This does NOT apply to"
- " saved tensors saved by custom autograd functions.)"
- )
-
- param_map = {}
-
- if params is not None:
- from TorchOpt.utils import _ModuleState
- if isinstance(params, _ModuleState):
- param_map.update(params.visual_contents)
- elif isinstance(params, Dict):
- param_map.update({v: k for k, v in params.items()})
- elif isinstance(params, Generator):
- param_map.update({v: k for k, v in params})
- else:
- for param in params:
- if isinstance(param, _ModuleState):
- param_map.update(param.visual_contents)
- elif isinstance(param, Generator):
- param_map.update({v: k for k, v in param})
- else:
- param_map.update({v: k for k, v in param.items()})
-
- node_attr = dict(
- style='filled',
- shape='box',
- align='left',
- fontsize='10',
- ranksep='0.1',
- height='0.2',
- fontname='monospace'
- )
- dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
- seen = set()
-
- def size_to_str(size):
- return '(' + (', ').join(['%d' % v for v in size]) + ')'
-
- def get_var_name(var, name=None):
- if not name:
- name = param_map[var] if var in param_map else ''
- return '%s\n %s' % (name, size_to_str(var.size()))
-
- def get_var_name_with_flag(var):
- if var in param_map:
- return '%s\n %s' % (
- param_map[var][0], size_to_str(param_map[var][1].size())
- )
- else:
- return None
-
- def add_nodes(fn):
- assert not torch.is_tensor(fn)
- if fn in seen:
- return
- seen.add(fn)
-
- if show_saved:
- for attr in dir(fn):
- if not attr.startswith(SAVED_PREFIX):
- continue
- val = getattr(fn, attr)
- seen.add(val)
- attr = attr[len(SAVED_PREFIX):]
- if torch.is_tensor(val):
- dot.edge(str(id(fn)), str(id(val)), dir="none")
- dot.node(str(id(val)), get_var_name(val, attr), fillcolor='orange')
- if isinstance(val, tuple):
- for i, t in enumerate(val):
- if torch.is_tensor(t):
- name = attr + '[%s]' % str(i)
- dot.edge(str(id(fn)), str(id(t)), dir="none")
- dot.node(str(id(t)), get_var_name(t, name), fillcolor='orange')
-
- if hasattr(fn, 'variable'):
- # if grad_accumulator, add the node for `.variable`
- var = fn.variable
- seen.add(var)
- dot.node(str(id(var)), get_var_name(var), fillcolor='lightblue')
- dot.edge(str(id(var)), str(id(fn)))
-
- fn_name = get_fn_name(fn, show_attrs, max_attr_chars)
- fn_fillcolor = None
- var_name = get_var_name_with_flag(fn)
- if var_name is not None:
- fn_name = '%s\n %s' % (fn_name, var_name)
- fn_fillcolor = 'lightblue'
-
- # add the node for this grad_fn
- dot.node(str(id(fn)), fn_name, fillcolor=fn_fillcolor)
-
- # recurse
- if hasattr(fn, 'next_functions'):
- for u in fn.next_functions:
- if u[0] is not None:
- dot.edge(str(id(u[0])), str(id(fn)))
- add_nodes(u[0])
-
- # note: this used to show .saved_tensors in pytorch0.2, but stopped
- # working* as it was moved to ATen and Variable-Tensor merged
- # also note that this still works for custom autograd functions
- if hasattr(fn, 'saved_tensors'):
- for t in fn.saved_tensors:
- dot.edge(str(id(t)), str(id(fn)))
- dot.node(str(id(t)), get_var_name(t), fillcolor='orange')
-
- def add_base_tensor(var, color='darkolivegreen1'):
- if var in seen:
- return
- seen.add(var)
- dot.node(str(id(var)), get_var_name(var), fillcolor=color)
- if (var.grad_fn):
- add_nodes(var.grad_fn)
- dot.edge(str(id(var.grad_fn)), str(id(var)))
- if var._is_view():
- add_base_tensor(var._base, color='darkolivegreen3')
- dot.edge(str(id(var._base)), str(id(var)), style="dotted")
-
- # handle multiple outputs
- if isinstance(var, tuple):
- for v in var:
- add_base_tensor(v)
- else:
- add_base_tensor(var)
-
- resize_graph(dot)
-
- return dot
-
-
-def resize_graph(dot, size_per_element=0.15, min_size=12):
- """Resize the graph according to how much content it contains.
- Modify the graph in place.
- """
- # Get the approximate number of nodes and edges
- num_rows = len(dot.body)
- content_size = num_rows * size_per_element
- size = max(min_size, content_size)
- size_str = str(size) + "," + str(size)
- dot.graph_attr.update(size=size_str)
diff --git a/docker/dev.dockerfile b/docker/dev.dockerfile
index 6c86fee0..01c00a0e 100644
--- a/docker/dev.dockerfile
+++ b/docker/dev.dockerfile
@@ -3,8 +3,8 @@
CPU_PARENT=ubuntu:18.04
GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04
-TAG=metaopt/TorchOpt
-VERSION=$(cat ./stable_baselines3/version.txt)
+TAG=metaopt/torchopt
+VERSION=$(shell git log -1 --format=%h)
if [[ ${USE_GPU} == "True" ]]; then
PARENT=${GPU_PARENT}
diff --git a/docs/conf.py b/docs/conf.py
index 8dfa64e6..4b42352a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -20,10 +20,10 @@
def get_version() -> str:
- # https://packaging.python.org/guides/single-sourcing-package-version/
- with open(os.path.join("..", "TorchOpt", "__init__.py"), "r") as f:
- init = f.read().split()
- return init[init.index("__version__") + 2][1:-1]
+ # https://packaging.python.org/guides/single-sourcing-package-version/
+ with open(os.path.join("..", "torchopt", "__init__.py"), "r") as f:
+ init = f.read().split()
+ return init[init.index("__version__") + 2][1:-1]
# -- Project information -----------------------------------------------------
@@ -41,7 +41,7 @@ def get_version() -> str:
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- "sphinx.ext.autodoc",
+ "sphinx.ext.autodoc",
]
# Add any paths that contain templates here, relative to this directory.
@@ -74,8 +74,8 @@ def get_version() -> str:
def setup(app):
- app.add_js_file("js/copybutton.js")
- app.add_css_file("css/style.css")
+ app.add_js_file("js/copybutton.js")
+ app.add_css_file("css/style.css")
# -- Extension configuration -------------------------------------------------
diff --git a/docs/index.rst b/docs/index.rst
index c7781713..90bf6a38 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,12 +1,12 @@
-:github_url: https://github.com/metaopt/TorchOpt/tree/main/docs
+:github_url: https://github.com/metaopt/torchopt/tree/main/docs
TorchOpt
--------
**TorchOpt** is a high-performance optimizer library built upon `PyTorch `_ for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features:
-* TorchOpt provides functional optimizer which enables `JAX-like `_ composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to `Optax `_ in JAX.
-* With the desgin of functional programing, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms.
+* TorchOpt provides functional optimizer which enables `JAX-like `_ composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to `Optax `_ in JAX.
+* With the desgin of functional programing, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms.
Installation
------------
@@ -18,14 +18,14 @@ Requirements
.. code-block:: bash
- pip install TorchOpt
+ pip install torchopt
You can also build shared libraries from source, use:
.. code-block:: bash
- git clone git@github.com:metaopt/TorchOpt.git
- cd TorchOpt
+ git clone git@github.com:metaopt/torchopt.git
+ cd torchopt
python setup.py build_from_source
The Team
@@ -37,10 +37,9 @@ Support
-------
If you are having issues, please let us know by filing an issue on our
-`issue tracker `_.
-
+`issue tracker `_.
License
-------
-TorchOpt is licensed under the Apache 2.0 License.
\ No newline at end of file
+TorchOpt is licensed under the Apache 2.0 License.
diff --git a/examples/L2R/README.md b/examples/L2R/README.md
index e2f8007e..8528fe24 100644
--- a/examples/L2R/README.md
+++ b/examples/L2R/README.md
@@ -1,14 +1,16 @@
# Learning-to-reweight-examples
-Code On Mnist reweighting example in paper [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050)] using `TorchOpt`. The idea of L2R is to use virtual update of inner-loop neural network optimisation to meta-learn the reweighting parameters for robust deep learning. We use `MetaSGD` as the inner-loop optimiser.
+Code On Mnist reweighting example in paper [Learning to Reweight Examples for Robust Deep Learning](https://arxiv.org/abs/1803.09050)] using TorchOpt. The idea of L2R is to use virtual update of inner-loop neural network optimisation to meta-learn the reweighting parameters for robust deep learning. We use `MetaSGD` as the inner-loop optimiser.
+
+## Usage
-# Usage
We use traditional supervised training as the baseline.
+
```bash
### Run both algorithms and conduct comparison
python3 train_l2r.py --algo both
-### For baseline
+### For baseline
python3 train_l2r.py --algo baseline
### For L2R algorithm
@@ -16,8 +18,9 @@ python3 train_l2r.py --algo l2r
```
# Results
+
The test accuracy comparison between baseline and L2R validate the effectiveness of algorithms.
+
-

+
-
diff --git a/examples/L2R/helper/argument.py b/examples/L2R/helper/argument.py
index e29bdb0a..1440f27a 100644
--- a/examples/L2R/helper/argument.py
+++ b/examples/L2R/helper/argument.py
@@ -15,36 +15,25 @@
import argparse
-import torch
-
def parse_args():
- parser = argparse.ArgumentParser([], description='L2R')
-
- parser.add_argument('--seed', type=int, default=42)
- parser.add_argument('--epoch', type=int, default=30, help='Training Epoch')
-
- parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
- parser.add_argument(
- '--pos_ratio',
- type=float,
- default=0.995,
- help='Ratio of positive examples in training'
- )
- parser.add_argument(
- '--ntest', type=int, default=500, help='Number of testing examples'
- )
- parser.add_argument(
- '--ntrain', type=int, default=5000, help='Number of testing examples'
- )
- parser.add_argument(
- '--nval', type=int, default=10, help='Number of valid examples'
- )
- parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
-
- ### For baseline
- parser.add_argument('--algo', type=str, default='both')
-
- args = parser.parse_args()
- # use the GPU if available
- return args
+ parser = argparse.ArgumentParser([], description='L2R')
+
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--epoch', type=int, default=30, help='Training Epoch')
+
+ parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
+ parser.add_argument(
+ '--pos_ratio', type=float, default=0.995, help='Ratio of positive examples in training'
+ )
+ parser.add_argument('--ntest', type=int, default=500, help='Number of testing examples')
+ parser.add_argument('--ntrain', type=int, default=5000, help='Number of testing examples')
+ parser.add_argument('--nval', type=int, default=10, help='Number of valid examples')
+ parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
+
+ # For baseline
+ parser.add_argument('--algo', type=str, default='both')
+
+ args = parser.parse_args()
+ # use the GPU if available
+ return args
diff --git a/examples/L2R/helper/model.py b/examples/L2R/helper/model.py
index 5a3ff2fa..d3a0beac 100644
--- a/examples/L2R/helper/model.py
+++ b/examples/L2R/helper/model.py
@@ -28,54 +28,49 @@
#
# Models for MNIST experiments.
#
-from __future__ import division, print_function
-import numpy as np
import torch
import torch.nn as nn
class LeNet5(nn.Module):
- def __init__(self, args):
- super(LeNet5, self).__init__()
- self.model = nn.Sequential(
- nn.Conv2d(1, 16, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, 5),
- nn.ReLU(), nn.MaxPool2d(2), nn.Flatten(), nn.Linear(512, 128), nn.ReLU(),
- nn.Linear(128, 1), nn.Sigmoid()
- )
- self.args = args
- self.meta_weights = torch.zeros(
- self.args.batch_size, requires_grad=True
- ).to(self.args.device)
- self.criterion = nn.BCELoss()
+ def __init__(self, args):
+ super(LeNet5, self).__init__()
+ self.model = nn.Sequential(
+ nn.Conv2d(1, 16, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, 5), nn.ReLU(),
+ nn.MaxPool2d(2), nn.Flatten(), nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, 1),
+ nn.Sigmoid()
+ )
+ self.args = args
+ self.meta_weights = torch.zeros(
+ self.args.batch_size, requires_grad=True
+ ).to(self.args.device)
+ self.criterion = nn.BCELoss()
- def forward(self, x):
- return self.model(x).squeeze(dim=-1)
+ def forward(self, x):
+ return self.model(x).squeeze(dim=-1)
- def reset_meta(self, size):
- self.meta_weights = torch.zeros(
- size, requires_grad=True
- ).to(self.args.device)
+ def reset_meta(self, size):
+ self.meta_weights = torch.zeros(size, requires_grad=True).to(self.args.device)
- def normalise(self):
- self.meta_weights = self.meta_weights.detach()
- weights_sum = torch.sum(self.meta_weights)
- weights_sum = weights_sum + 1 if weights_sum == 0 else weights_sum
- self.meta_weights /= weights_sum
+ def normalise(self):
+ self.meta_weights = self.meta_weights.detach()
+ weights_sum = torch.sum(self.meta_weights)
+ weights_sum = weights_sum + 1 if weights_sum == 0 else weights_sum
+ self.meta_weights /= weights_sum
- def inner_loss(self, train_x, train_y):
- result = self.forward(train_x)
+ def inner_loss(self, train_x, train_y):
+ result = self.forward(train_x)
- # manually implement bce_loss to make the loss differentiable w.r.t self.meta_weights
- loss = -(
- train_y * torch.log(result + 1e-10) +
- (1 - train_y) * torch.log(1 - result + 1e-10)
- )
- weighted_loss = torch.sum(self.meta_weights * loss)
- return weighted_loss
+ # manually implement bce_loss to make the loss differentiable w.r.t self.meta_weights
+ loss = -(
+ train_y * torch.log(result + 1e-10) + (1 - train_y) * torch.log(1 - result + 1e-10)
+ )
+ weighted_loss = torch.sum(self.meta_weights * loss)
+ return weighted_loss
- def outer_loss(self, valid_x, valid_y):
- result = self.forward(valid_x)
- loss = self.criterion(result, valid_y)
- return loss
+ def outer_loss(self, valid_x, valid_y):
+ result = self.forward(valid_x)
+ loss = self.criterion(result, valid_y)
+ return loss
diff --git a/examples/L2R/helper/utils.py b/examples/L2R/helper/utils.py
index 96f469b7..0fb01ad4 100644
--- a/examples/L2R/helper/utils.py
+++ b/examples/L2R/helper/utils.py
@@ -19,161 +19,142 @@
import random
import numpy as np
-import seaborn as sns
import torch
from torch.utils.data import TensorDataset
def get_imbalance_dataset(
- mnist_train,
- mnist_test,
- pos_ratio=0.9,
- ntrain=5000,
- nval=10,
- ntest=500,
- class_0=4,
- class_1=9
+ mnist_train, mnist_test, pos_ratio=0.9, ntrain=5000, nval=10, ntest=500, class_0=4, class_1=9
):
- ratio = 1 - pos_ratio
- ratio_test = 0.5
-
- # In training, we have 10% 4 and 90% 9.
- # In testing, we have 50% 4 and 50% 9.
- x_train = mnist_train.train_data.numpy() / 255.0
- y_train = mnist_train.train_labels.numpy()
- x_test = mnist_test.test_data.numpy() / 255.0
- y_test = mnist_test.test_labels.numpy()
- x_train_0 = x_train[y_train == class_0]
- x_test_0 = x_test[y_test == class_0]
-
- # First shuffle, negative.
- idx = np.arange(x_train_0.shape[0])
- np.random.shuffle(idx)
- x_train_0 = x_train_0[idx]
-
- nval_small_neg = int(np.floor(nval * ratio_test))
- ntrain_small_neg = int(np.floor(ntrain * ratio)) - nval_small_neg
-
- x_val_0 = x_train_0[:nval_small_neg] # 450 4 in validation.
- x_train_0 = x_train_0[nval_small_neg:nval_small_neg + ntrain_small_neg
- ] # 500 4 in training.
-
- print('Number of train negative classes', ntrain_small_neg)
- print('Number of val negative classes', nval_small_neg)
-
- idx = np.arange(x_test_0.shape[0])
- np.random.shuffle(idx)
- x_test_0 = x_test_0[:int(np.floor(ntest * ratio_test))] # 450 4 in testing.
-
- x_train_1 = x_train[y_train == class_1]
- x_test_1 = x_test[y_test == class_1]
-
- # First shuffle, positive.
- idx = np.arange(x_train_1.shape[0])
- np.random.shuffle(idx)
- x_train_1 = x_train_1[idx]
-
- nvalsmall_pos = int(np.floor(nval * (1 - ratio_test)))
- ntrainsmall_pos = int(np.floor(ntrain * (1 - ratio))) - nvalsmall_pos
-
- x_val_1 = x_train_1[:nvalsmall_pos] # 50 9 in validation.
- x_train_1 = x_train_1[nvalsmall_pos:nvalsmall_pos + ntrainsmall_pos
- ] # 4500 9 in training.
-
- idx = np.arange(x_test_1.shape[0])
- np.random.shuffle(idx)
- x_test_1 = x_test_1[idx]
- x_test_1 = x_test_1[:int(np.floor(ntest * (1 - ratio_test)))
- ] # 500 9 in testing.
-
- print('Number of train positive classes', ntrainsmall_pos)
- print('Number of val positive classes', nvalsmall_pos)
-
- y_train_subset = np.concatenate(
- [np.zeros([x_train_0.shape[0]]),
- np.ones([x_train_1.shape[0]])]
- )
- y_val_subset = np.concatenate(
- [np.zeros([x_val_0.shape[0]]),
- np.ones([x_val_1.shape[0]])]
- )
- y_test_subset = np.concatenate(
- [np.zeros([x_test_0.shape[0]]),
- np.ones([x_test_1.shape[0]])]
- )
-
- y_train_pos_subset = np.ones([x_train_1.shape[0]])
- y_train_neg_subset = np.zeros([x_train_0.shape[0]])
-
- x_train_subset = np.concatenate([x_train_0, x_train_1], axis=0)[:,
- None, :, :]
- x_val_subset = np.concatenate([x_val_0, x_val_1], axis=0)[:, None, :, :]
- x_test_subset = np.concatenate([x_test_0, x_test_1], axis=0)[:, None, :, :]
-
- x_train_pos_subset = x_train_1[:, None, :, :]
- x_train_neg_subset = x_train_0[:, None, :, :]
-
- # Final shuffle.
- idx = np.arange(x_train_subset.shape[0])
- np.random.shuffle(idx)
- x_train_subset = x_train_subset[idx].astype(np.float32)
- y_train_subset = y_train_subset[idx].astype(np.float32)
-
- idx = np.arange(x_val_subset.shape[0])
- np.random.shuffle(idx)
- x_val_subset = x_val_subset[idx].astype(np.float32)
- y_val_subset = y_val_subset[idx].astype(np.float32)
-
- idx = np.arange(x_test_subset.shape[0])
- np.random.shuffle(idx)
- x_test_subset = x_test_subset[idx].astype(np.float32)
- y_test_subset = y_test_subset[idx].astype(np.float32)
-
- x_train_subset, y_train_subset, x_val_subset, y_val_subset, x_test_subset, y_test_subset = torch.tensor(
- x_train_subset
- ), torch.tensor(y_train_subset), torch.tensor(x_val_subset), torch.tensor(
- y_val_subset
- ), torch.tensor(x_test_subset), torch.tensor(y_test_subset)
-
- train_set, val_set, test_set = TensorDataset(
- x_train_subset, y_train_subset
- ), TensorDataset(x_val_subset,
- y_val_subset), TensorDataset(x_test_subset, y_test_subset)
-
- return train_set, val_set, test_set
+ ratio = 1 - pos_ratio
+ ratio_test = 0.5
+
+ # In training, we have 10% 4 and 90% 9.
+ # In testing, we have 50% 4 and 50% 9.
+ x_train = mnist_train.train_data.numpy() / 255.0
+ y_train = mnist_train.train_labels.numpy()
+ x_test = mnist_test.test_data.numpy() / 255.0
+ y_test = mnist_test.test_labels.numpy()
+ x_train_0 = x_train[y_train == class_0]
+ x_test_0 = x_test[y_test == class_0]
+
+ # First shuffle, negative.
+ idx = np.arange(x_train_0.shape[0])
+ np.random.shuffle(idx)
+ x_train_0 = x_train_0[idx]
+
+ nval_small_neg = int(np.floor(nval * ratio_test))
+ ntrain_small_neg = int(np.floor(ntrain * ratio)) - nval_small_neg
+
+ x_val_0 = x_train_0[:nval_small_neg] # 450 4 in validation.
+ x_train_0 = x_train_0[nval_small_neg:nval_small_neg + ntrain_small_neg] # 500 4 in training.
+
+ print('Number of train negative classes', ntrain_small_neg)
+ print('Number of val negative classes', nval_small_neg)
+
+ idx = np.arange(x_test_0.shape[0])
+ np.random.shuffle(idx)
+ x_test_0 = x_test_0[:int(np.floor(ntest * ratio_test))] # 450 4 in testing.
+
+ x_train_1 = x_train[y_train == class_1]
+ x_test_1 = x_test[y_test == class_1]
+
+ # First shuffle, positive.
+ idx = np.arange(x_train_1.shape[0])
+ np.random.shuffle(idx)
+ x_train_1 = x_train_1[idx]
+
+ nvalsmall_pos = int(np.floor(nval * (1 - ratio_test)))
+ ntrainsmall_pos = int(np.floor(ntrain * (1 - ratio))) - nvalsmall_pos
+
+ x_val_1 = x_train_1[:nvalsmall_pos] # 50 9 in validation.
+ x_train_1 = x_train_1[nvalsmall_pos:nvalsmall_pos + ntrainsmall_pos] # 4500 9 in training.
+
+ idx = np.arange(x_test_1.shape[0])
+ np.random.shuffle(idx)
+ x_test_1 = x_test_1[idx]
+ x_test_1 = x_test_1[:int(np.floor(ntest * (1 - ratio_test)))] # 500 9 in testing.
+
+ print('Number of train positive classes', ntrainsmall_pos)
+ print('Number of val positive classes', nvalsmall_pos)
+
+ y_train_subset = np.concatenate([np.zeros([x_train_0.shape[0]]), np.ones([x_train_1.shape[0]])])
+ y_val_subset = np.concatenate([np.zeros([x_val_0.shape[0]]), np.ones([x_val_1.shape[0]])])
+ y_test_subset = np.concatenate([np.zeros([x_test_0.shape[0]]), np.ones([x_test_1.shape[0]])])
+
+ y_train_pos_subset = np.ones([x_train_1.shape[0]])
+ y_train_neg_subset = np.zeros([x_train_0.shape[0]])
+
+ x_train_subset = np.concatenate([x_train_0, x_train_1], axis=0)[:, None, :, :]
+ x_val_subset = np.concatenate([x_val_0, x_val_1], axis=0)[:, None, :, :]
+ x_test_subset = np.concatenate([x_test_0, x_test_1], axis=0)[:, None, :, :]
+
+ x_train_pos_subset = x_train_1[:, None, :, :]
+ x_train_neg_subset = x_train_0[:, None, :, :]
+
+ # Final shuffle.
+ idx = np.arange(x_train_subset.shape[0])
+ np.random.shuffle(idx)
+ x_train_subset = x_train_subset[idx].astype(np.float32)
+ y_train_subset = y_train_subset[idx].astype(np.float32)
+
+ idx = np.arange(x_val_subset.shape[0])
+ np.random.shuffle(idx)
+ x_val_subset = x_val_subset[idx].astype(np.float32)
+ y_val_subset = y_val_subset[idx].astype(np.float32)
+
+ idx = np.arange(x_test_subset.shape[0])
+ np.random.shuffle(idx)
+ x_test_subset = x_test_subset[idx].astype(np.float32)
+ y_test_subset = y_test_subset[idx].astype(np.float32)
+
+ x_train_subset, y_train_subset, x_val_subset, y_val_subset, x_test_subset, y_test_subset = (
+ torch.tensor(x_train_subset), torch.tensor(y_train_subset), torch.tensor(x_val_subset),
+ torch.tensor(y_val_subset), torch.tensor(x_test_subset), torch.tensor(y_test_subset)
+ )
+
+ train_set, val_set, test_set = (
+ TensorDataset(x_train_subset, y_train_subset), TensorDataset(x_val_subset, y_val_subset),
+ TensorDataset(x_test_subset, y_test_subset)
+ )
+
+ return train_set, val_set, test_set
def set_seed(seed, cudnn=True):
- """
+ """
Seed everything we can!
Note that gym environments might need additional seeding (env.seed(seed)),
and num_workers needs to be set to 1.
"""
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.random.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- # note: the below slows down the code but makes it reproducible
- torch.cuda.manual_seed_all(
- seed
- ) # Sets the seed for generating random numbers on all GPUs. It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.
- if cudnn:
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.random.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ # note: the below slows down the code but makes it reproducible
+ # Sets the seed for generating random numbers on all GPUs. It’s safe to
+ # call this function if CUDA is not available; in that case, it is
+ # silently ignored.
+ torch.cuda.manual_seed_all(seed)
+ if cudnn:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
def plot(baseline, l2r):
- import matplotlib.pyplot as plt
- import numpy as np
- import seaborn as sns
- sns.set(style='darkgrid')
- sns.set_theme(style="darkgrid")
- plt.plot(baseline, label='baseline')
- plt.plot(l2r, label='l2r')
- plt.legend()
- plt.ylabel('Test acc')
- plt.xlabel('Epoch')
- plt.title('Comparison between Baseline and L2R')
- plt.savefig('./result.png')
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import seaborn as sns
+
+ sns.set(style='darkgrid')
+ sns.set_theme(style="darkgrid")
+ plt.plot(baseline, label='baseline')
+ plt.plot(l2r, label='l2r')
+ plt.legend()
+ plt.ylabel('Test acc')
+ plt.xlabel('Epoch')
+ plt.title('Comparison between Baseline and L2R')
+ plt.savefig('./result.png')
diff --git a/examples/L2R/train_l2r.py b/examples/L2R/train_l2r.py
index 22deb2ce..c04a90e1 100644
--- a/examples/L2R/train_l2r.py
+++ b/examples/L2R/train_l2r.py
@@ -28,264 +28,227 @@
#
import json
-import os
-import time
import numpy as np
import torch
-import torch.nn as nn
-from helper.argument import parse_args
-from helper.model import LeNet5
-from helper.utils import get_imbalance_dataset, plot, set_seed
-from torch import device
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import MNIST
-import TorchOpt
+import torchopt
+
+from .helper.argument import parse_args
+from .helper.model import LeNet5
+from .helper.utils import get_imbalance_dataset, plot, set_seed
def run_baseline(args, mnist_train, mnist_test):
- print('Run Baseline')
- set_seed(args.seed)
-
- pos_ratio = args.pos_ratio
- ntrain = args.ntrain
- nval = args.nval
- ntest = args.ntest
- epoch = args.epoch
-
- folder = './result/baseline/'
- writer = SummaryWriter('./result/baseline')
- with open('./result/baseline/config.json', 'w') as f:
- json.dump(args.__dict__, f)
-
- args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- train_set, val_set, test_set = get_imbalance_dataset(
- mnist_train,
- mnist_test,
- pos_ratio=pos_ratio,
- ntrain=ntrain,
- nval=nval,
- ntest=ntest
- )
- train_loader = DataLoader(
- train_set, batch_size=args.batch_size, shuffle=True, num_workers=4
- )
- valid_loader = DataLoader(
- val_set, batch_size=args.batch_size, shuffle=True, num_workers=1
- )
- test_loader = DataLoader(
- test_set, batch_size=args.batch_size, shuffle=True, num_workers=1
- )
- model = LeNet5(args).to(args.device)
-
- model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr)
-
- step = 0
- running_train_loss = []
- test_acc_result = []
- for _epoch in range(epoch):
- model.train()
- for idx, (train_x, train_label) in enumerate(train_loader):
- train_x, train_label = train_x.to(args.device
- ), train_label.to(args.device)
- outer_loss = model.outer_loss(train_x, train_label)
-
- model_optimiser.zero_grad()
- outer_loss.backward()
- model_optimiser.step()
-
- running_train_loss.append(outer_loss.item())
- writer.add_scalar('train_loss', outer_loss.item(), step)
-
- if step % 10 == 0 and step > 0:
- running_train_mean = np.mean(np.array(running_train_loss))
- print(
- "EPOCH: {}, BATCH: {}, LOSS: {}".format(
- _epoch, idx, running_train_mean
- )
- )
- writer.add_scalar('running_train_loss', running_train_mean, step)
- running_train_loss = []
-
- step += 1
-
- print('Beginning to Test')
- model.eval()
- train_acc = evaluate(train_loader, model, args)
- test_acc = evaluate(test_loader, model, args)
- model.train()
-
- writer.add_scalar('train_acc', train_acc, _epoch)
- writer.add_scalar('test_acc', test_acc, _epoch)
- test_acc_result.append(test_acc)
- print(
- "EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}".format(
- _epoch, train_acc, test_acc
- )
+ print('Run Baseline')
+ set_seed(args.seed)
+
+ pos_ratio = args.pos_ratio
+ ntrain = args.ntrain
+ nval = args.nval
+ ntest = args.ntest
+ epoch = args.epoch
+
+ folder = './result/baseline/'
+ writer = SummaryWriter('./result/baseline')
+ with open('./result/baseline/config.json', 'w') as f:
+ json.dump(args.__dict__, f)
+
+ args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ train_set, val_set, test_set = get_imbalance_dataset(
+ mnist_train, mnist_test, pos_ratio=pos_ratio, ntrain=ntrain, nval=nval, ntest=ntest
)
- return test_acc_result
+ train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
+ valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
+ test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
+ model = LeNet5(args).to(args.device)
+
+ model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+ step = 0
+ running_train_loss = []
+ test_acc_result = []
+ for _epoch in range(epoch):
+ model.train()
+ for idx, (train_x, train_label) in enumerate(train_loader):
+ train_x, train_label = train_x.to(args.device), train_label.to(args.device)
+ outer_loss = model.outer_loss(train_x, train_label)
+
+ model_optimiser.zero_grad()
+ outer_loss.backward()
+ model_optimiser.step()
+
+ running_train_loss.append(outer_loss.item())
+ writer.add_scalar('train_loss', outer_loss.item(), step)
+
+ if step % 10 == 0 and step > 0:
+ running_train_mean = np.mean(np.array(running_train_loss))
+ print("EPOCH: {}, BATCH: {}, LOSS: {}".format(_epoch, idx, running_train_mean))
+ writer.add_scalar('running_train_loss', running_train_mean, step)
+ running_train_loss = []
+
+ step += 1
+
+ print('Beginning to Test')
+ model.eval()
+ train_acc = evaluate(train_loader, model, args)
+ test_acc = evaluate(test_loader, model, args)
+ model.train()
+
+ writer.add_scalar('train_acc', train_acc, _epoch)
+ writer.add_scalar('test_acc', test_acc, _epoch)
+ test_acc_result.append(test_acc)
+ print("EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}".format(_epoch, train_acc, test_acc))
+ return test_acc_result
def run_L2R(args, mnist_train, mnist_test):
- print('Run L2R')
- set_seed(args.seed)
-
- pos_ratio = args.pos_ratio
- ntrain = args.ntrain
- nval = args.nval
- ntest = args.ntest
- epoch = args.epoch
-
- folder = './result/l2r/'
- writer = SummaryWriter('./result/l2r/log')
- with open('./result/l2r/config.json', 'w') as f:
- json.dump(args.__dict__, f)
-
- args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- train_set, val_set, test_set = get_imbalance_dataset(
- mnist_train,
- mnist_test,
- pos_ratio=pos_ratio,
- ntrain=ntrain,
- nval=nval,
- ntest=ntest
- )
- train_loader = DataLoader(
- train_set, batch_size=args.batch_size, shuffle=True, num_workers=2
- )
- valid_loader = DataLoader(
- val_set, batch_size=args.batch_size, shuffle=True, num_workers=1
- )
- test_loader = DataLoader(
- test_set, batch_size=args.batch_size, shuffle=True, num_workers=1
- )
- model = LeNet5(args).to(args.device)
- model_optimiser = TorchOpt.MetaSGD(model, lr=args.lr)
- real_model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr)
-
- step = 0
- time_bp = 0
- running_valid_loss = []
- valid = iter(valid_loader)
- running_train_loss = []
- test_acc_result = []
- for _epoch in range(epoch):
- model.train()
- for idx, (train_x, train_label) in enumerate(train_loader):
- try:
- valid_x, valid_label = valid.next()
- except:
- valid = iter(valid_loader)
- valid_x, valid_label = valid.next()
- train_x, train_label, valid_x, valid_label = train_x.to(
- args.device
- ), train_label.to(args.device), valid_x.to(args.device
- ), valid_label.to(args.device)
-
- # reset meta-parameter weights
- model.reset_meta(size=train_x.size(0))
-
- net_state_dict = TorchOpt.extract_state_dict(model)
- optim_state_dict = TorchOpt.extract_state_dict(model_optimiser)
-
- for _ in range(1):
- inner_loss = model.inner_loss(train_x, train_label)
- model_optimiser.step(inner_loss)
-
- # caclulate outer_loss, deirve meta-gradient and normalise
- outer_loss = model.outer_loss(valid_x, valid_label)
- model.meta_weights = - \
- torch.autograd.grad(outer_loss, model.meta_weights)[0]
- model.meta_weights = torch.nn.ReLU()(model.meta_weights)
- model.normalise()
-
- # log loss
- running_valid_loss.append(outer_loss.item())
- writer.add_scalar('validation_loss', outer_loss.item(), step)
-
- # reset the model and model optimiser
- TorchOpt.recover_state_dict(model, net_state_dict)
- TorchOpt.recover_state_dict(model_optimiser, optim_state_dict)
-
- # reuse inner_adapt to conduct real update based on learned meta weights
- inner_loss = model.inner_loss(train_x, train_label)
- for _ in range(1):
- inner_loss = model.inner_loss(train_x, train_label)
- real_model_optimiser.zero_grad()
- inner_loss.backward()
- real_model_optimiser.step()
-
- running_train_loss.append(inner_loss.item())
- writer.add_scalar('weighted_train_loss', inner_loss.item(), step)
-
- if step % 10 == 0 and step > 0:
- running_valid_mean = np.mean(np.array(running_valid_loss))
- running_train_mean = np.mean(np.array(running_train_loss))
- print(
- "EPOCH: {}, BATCH: {}, WEIGHTED_TRAIN_LOSS: {}, VALID_LOSS: {}"
- .format(_epoch, idx, running_train_mean, running_valid_mean)
- )
- running_valid_loss = []
- running_train_loss = []
- writer.add_scalar('running_valid_loss', running_valid_mean, step)
- writer.add_scalar('running_train_loss', running_train_mean, step)
-
- step += 1
-
- print('Beginning to Test')
- model.eval()
- train_acc = evaluate(train_loader, model, args)
- test_acc = evaluate(test_loader, model, args)
- model.train()
-
- writer.add_scalar('train_acc', train_acc, _epoch)
- writer.add_scalar('test_acc', test_acc, _epoch)
- test_acc_result.append(test_acc)
- print(
- "EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}".format(
- _epoch, train_acc, test_acc
- )
+ print('Run L2R')
+ set_seed(args.seed)
+
+ pos_ratio = args.pos_ratio
+ ntrain = args.ntrain
+ nval = args.nval
+ ntest = args.ntest
+ epoch = args.epoch
+
+ folder = './result/l2r/'
+ writer = SummaryWriter('./result/l2r/log')
+ with open('./result/l2r/config.json', 'w') as f:
+ json.dump(args.__dict__, f)
+
+ args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ train_set, val_set, test_set = get_imbalance_dataset(
+ mnist_train, mnist_test, pos_ratio=pos_ratio, ntrain=ntrain, nval=nval, ntest=ntest
)
- return test_acc_result
+ train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=2)
+ valid_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
+ test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
+ model = LeNet5(args).to(args.device)
+ model_optimiser = torchopt.MetaSGD(model, lr=args.lr)
+ real_model_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+ step = 0
+ time_bp = 0
+ running_valid_loss = []
+ valid = iter(valid_loader)
+ running_train_loss = []
+ test_acc_result = []
+ for _epoch in range(epoch):
+ model.train()
+ for idx, (train_x, train_label) in enumerate(train_loader):
+ try:
+ valid_x, valid_label = valid.next()
+ except BaseException:
+ valid = iter(valid_loader)
+ valid_x, valid_label = valid.next()
+ train_x, train_label, valid_x, valid_label = (
+ train_x.to(args.device), train_label.to(args.device), valid_x.to(args.device),
+ valid_label.to(args.device)
+ )
+
+ # reset meta-parameter weights
+ model.reset_meta(size=train_x.size(0))
+
+ net_state_dict = torchopt.extract_state_dict(model)
+ optim_state_dict = torchopt.extract_state_dict(model_optimiser)
+
+ for _ in range(1):
+ inner_loss = model.inner_loss(train_x, train_label)
+ model_optimiser.step(inner_loss)
+
+ # caclulate outer_loss, deirve meta-gradient and normalise
+ outer_loss = model.outer_loss(valid_x, valid_label)
+ model.meta_weights = - \
+ torch.autograd.grad(outer_loss, model.meta_weights)[0]
+ model.meta_weights = torch.nn.ReLU()(model.meta_weights)
+ model.normalise()
+
+ # log loss
+ running_valid_loss.append(outer_loss.item())
+ writer.add_scalar('validation_loss', outer_loss.item(), step)
+
+ # reset the model and model optimiser
+ torchopt.recover_state_dict(model, net_state_dict)
+ torchopt.recover_state_dict(model_optimiser, optim_state_dict)
+
+ # reuse inner_adapt to conduct real update based on learned meta weights
+ inner_loss = model.inner_loss(train_x, train_label)
+ for _ in range(1):
+ inner_loss = model.inner_loss(train_x, train_label)
+ real_model_optimiser.zero_grad()
+ inner_loss.backward()
+ real_model_optimiser.step()
+
+ running_train_loss.append(inner_loss.item())
+ writer.add_scalar('weighted_train_loss', inner_loss.item(), step)
+
+ if step % 10 == 0 and step > 0:
+ running_valid_mean = np.mean(np.array(running_valid_loss))
+ running_train_mean = np.mean(np.array(running_train_loss))
+ print(
+ "EPOCH: {}, BATCH: {}, WEIGHTED_TRAIN_LOSS: {}, VALID_LOSS: {}".format(
+ _epoch, idx, running_train_mean, running_valid_mean
+ )
+ )
+ running_valid_loss = []
+ running_train_loss = []
+ writer.add_scalar('running_valid_loss', running_valid_mean, step)
+ writer.add_scalar('running_train_loss', running_train_mean, step)
+
+ step += 1
+
+ print('Beginning to Test')
+ model.eval()
+ train_acc = evaluate(train_loader, model, args)
+ test_acc = evaluate(test_loader, model, args)
+ model.train()
+
+ writer.add_scalar('train_acc', train_acc, _epoch)
+ writer.add_scalar('test_acc', test_acc, _epoch)
+ test_acc_result.append(test_acc)
+ print("EPOCH: {}, TRAIN_ACC: {}, TEST_ACC: {}".format(_epoch, train_acc, test_acc))
+ return test_acc_result
def evaluate(data_loader, model, args):
- running_accuracy = 0
- total = 0
- with torch.no_grad():
- for data in data_loader:
- inputs, outputs = data
- inputs, outputs = inputs.to(args.device), outputs.to(args.device)
- predicted = model(inputs)
- predicted[predicted >= 0.5] = 1.0
- predicted[predicted < 0.5] = 0.0
- total += outputs.size(0)
- running_accuracy += (predicted == outputs).sum().item()
-
- accuracy = running_accuracy / total
- return accuracy
+ running_accuracy = 0
+ total = 0
+ with torch.no_grad():
+ for data in data_loader:
+ inputs, outputs = data
+ inputs, outputs = inputs.to(args.device), outputs.to(args.device)
+ predicted = model(inputs)
+ predicted[predicted >= 0.5] = 1.0
+ predicted[predicted < 0.5] = 0.0
+ total += outputs.size(0)
+ running_accuracy += (predicted == outputs).sum().item()
+
+ accuracy = running_accuracy / total
+ return accuracy
def main():
- mnist_train = MNIST(root='./helper/mnist_data', download=True, train=True)
- mnist_test = MNIST(root='./helper/mnist_data', download=True, train=False)
- args = parse_args()
-
- assert args.algo in ['baseline', 'l2r', 'both']
- if args.algo == 'baseline':
- run_baseline(args, mnist_train, mnist_test)
- elif args.algo == 'l2r':
- run_L2R(args, mnist_train, mnist_test)
- else:
- baseline_test_acc = run_baseline(args, mnist_train, mnist_test)
+ mnist_train = MNIST(root='./helper/mnist_data', download=True, train=True)
+ mnist_test = MNIST(root='./helper/mnist_data', download=True, train=False)
args = parse_args()
- l2r_test_acc = run_L2R(args, mnist_train, mnist_test)
- plot(baseline_test_acc, l2r_test_acc)
+
+ assert args.algo in ['baseline', 'l2r', 'both']
+ if args.algo == 'baseline':
+ run_baseline(args, mnist_train, mnist_test)
+ elif args.algo == 'l2r':
+ run_L2R(args, mnist_train, mnist_test)
+ else:
+ baseline_test_acc = run_baseline(args, mnist_train, mnist_test)
+ args = parse_args()
+ l2r_test_acc = run_L2R(args, mnist_train, mnist_test)
+ plot(baseline_test_acc, l2r_test_acc)
if __name__ == '__main__':
- main()
+ main()
diff --git a/examples/LOLA/README.md b/examples/LOLA/README.md
index 8ef37723..1decc337 100755
--- a/examples/LOLA/README.md
+++ b/examples/LOLA/README.md
@@ -1,8 +1,9 @@
# LOLA-examples
-Code On LOLA a in paper [Learning with Opponent-Learning Awareness](https://arxiv.org/abs/1709.04326)] using `TorchOpt`. The LOLA learning rule includes a term that accounts for the impact of one agent's policy on the anticipated parameter update of the other agents. We use `MetaSGD` as the inner-loop optimiser.
+Code On LOLA a in paper [Learning with Opponent-Learning Awareness](https://arxiv.org/abs/1709.04326)] using TorchOpt. The LOLA learning rule includes a term that accounts for the impact of one agent's policy on the anticipated parameter update of the other agents. We use `MetaSGD` as the inner-loop optimiser.
+
+## Usage
-# Usage
```bash
### Run LOLA
python3 lola_dice.py
@@ -11,9 +12,10 @@ python3 lola_dice.py
python3 visualise.py
```
-# Results
+## Results
+
The figure illustrate the experimental result.
+
-

+
-
diff --git a/examples/LOLA/helper/agent.py b/examples/LOLA/helper/agent.py
index 7676cadd..969a04f7 100755
--- a/examples/LOLA/helper/agent.py
+++ b/examples/LOLA/helper/agent.py
@@ -19,38 +19,36 @@
import torch
import torch.nn as nn
-import TorchOpt
+import torchopt
class theta_model(nn.Module):
- def __init__(self, theta):
- super().__init__()
- self.theta = nn.Parameter(torch.tensor(theta.detach(), requires_grad=True))
+ def __init__(self, theta):
+ super().__init__()
+ self.theta = nn.Parameter(torch.tensor(theta.detach(), requires_grad=True))
class Agent():
- def __init__(self, args):
+ def __init__(self, args):
- self.args = args
- # init theta and its optimizer
- self.theta = nn.Parameter(torch.zeros(5, requires_grad=True))
- self.theta_optimizer = torch.optim.Adam((self.theta,), lr=args.lr_out)
+ self.args = args
+ # init theta and its optimizer
+ self.theta = nn.Parameter(torch.zeros(5, requires_grad=True))
+ self.theta_optimizer = torch.optim.Adam((self.theta,), lr=args.lr_out)
- # init values and its optimizer
- self.values = nn.Parameter(torch.zeros(5, requires_grad=True))
- self.value_optimizer = torch.optim.Adam((self.values,), lr=args.lr_v)
+ # init values and its optimizer
+ self.values = nn.Parameter(torch.zeros(5, requires_grad=True))
+ self.value_optimizer = torch.optim.Adam((self.values,), lr=args.lr_v)
- self.set_virtual()
+ self.set_virtual()
- def set_virtual(self):
- self.virtual_theta = theta_model(self.theta)
- self.virtual_optimiser = TorchOpt.MetaSGD(
- self.virtual_theta, lr=self.args.lr_in
- )
+ def set_virtual(self):
+ self.virtual_theta = theta_model(self.theta)
+ self.virtual_optimiser = torchopt.MetaSGD(self.virtual_theta, lr=self.args.lr_in)
- def value_update(self, loss):
- self.value_optimizer.zero_grad()
- loss.backward()
- self.value_optimizer.step()
+ def value_update(self, loss):
+ self.value_optimizer.zero_grad()
+ loss.backward()
+ self.value_optimizer.step()
diff --git a/examples/LOLA/helper/argument.py b/examples/LOLA/helper/argument.py
index 33a29f38..b8e67cc5 100755
--- a/examples/LOLA/helper/argument.py
+++ b/examples/LOLA/helper/argument.py
@@ -17,35 +17,19 @@
def parse_args():
- parser = argparse.ArgumentParser([], description='LOLA')
+ parser = argparse.ArgumentParser([], description='LOLA')
- parser.add_argument('--seed', type=int, default=6666)
- parser.add_argument(
- '--lr_in', type=float, default=0.3, help='Inner Learning rate'
- )
+ parser.add_argument('--seed', type=int, default=6666)
+ parser.add_argument('--lr_in', type=float, default=0.3, help='Inner Learning rate')
- parser.add_argument(
- '--lr_out', type=float, default=0.2, help='Outer learning rate'
- )
- parser.add_argument(
- '--lr_v', type=float, default=0.1, help='Learning rate of value function'
- )
- parser.add_argument(
- '--gamma', type=float, default=0.96, help='Discount factor'
- )
- parser.add_argument(
- '--n_update', type=int, default=100, help='Number of updates'
- )
- parser.add_argument(
- '--n_lookaheads', type=int, default=1, help='Number of updates'
- )
- parser.add_argument(
- '--len_rollout', type=int, default=150, help='Length of IPD'
- )
- parser.add_argument(
- '--batch_size', type=int, default=1024, help='Natch size'
- )
- parser.add_argument('--use_baseline', action='store_false', default=True)
+ parser.add_argument('--lr_out', type=float, default=0.2, help='Outer learning rate')
+ parser.add_argument('--lr_v', type=float, default=0.1, help='Learning rate of value function')
+ parser.add_argument('--gamma', type=float, default=0.96, help='Discount factor')
+ parser.add_argument('--n_update', type=int, default=100, help='Number of updates')
+ parser.add_argument('--n_lookaheads', type=int, default=1, help='Number of updates')
+ parser.add_argument('--len_rollout', type=int, default=150, help='Length of IPD')
+ parser.add_argument('--batch_size', type=int, default=1024, help='Natch size')
+ parser.add_argument('--use_baseline', action='store_false', default=True)
- args = parser.parse_args()
- return args
+ args = parser.parse_args()
+ return args
diff --git a/examples/LOLA/helper/env.py b/examples/LOLA/helper/env.py
index bb72c5b0..df4522f6 100755
--- a/examples/LOLA/helper/env.py
+++ b/examples/LOLA/helper/env.py
@@ -22,79 +22,75 @@
class OneHot(gym.Space):
- """
+ """
One-hot space. Used as the observation space.
"""
- def __init__(self, n):
- self.n = n
+ def __init__(self, n):
+ self.n = n
- def sample(self):
- return np.random.multinomial(1, [1. / self.n] * self.n)
+ def sample(self):
+ return np.random.multinomial(1, [1. / self.n] * self.n)
- def contains(self, x):
- return isinstance(x, np.ndarray) and \
- x.shape == (self.n, ) and \
- np.all(np.logical_or(x == 0, x == 1)) and \
- np.sum(x) == 1
+ def contains(self, x):
+ return isinstance(x, np.ndarray) and \
+ x.shape == (self.n, ) and \
+ np.all(np.logical_or(x == 0, x == 1)) and \
+ np.sum(x) == 1
- @property
- def shape(self):
- return (self.n,)
+ @property
+ def shape(self):
+ return (self.n,)
- def __repr__(self):
- return "OneHot(%d)" % self.n
+ def __repr__(self):
+ return "OneHot(%d)" % self.n
- def __eq__(self, other):
- return self.n == other.n
+ def __eq__(self, other):
+ return self.n == other.n
class IPD(gym.Env):
- """
+ """
A two-agent vectorized environment.
Possible actions for each agent are (C)ooperate and (D)efect.
"""
- # Possible actions
- NUM_AGENTS = 2
- NUM_ACTIONS = 2
- NUM_STATES = 5
-
- def __init__(self, max_steps, batch_size=1):
- self.max_steps = max_steps
- self.batch_size = batch_size
- self.payout_mat = np.array([[-2, 0], [-3, -1]])
- self.states = np.array([[1, 2], [3, 4]])
-
- self.action_space = Tuple(
- [Discrete(self.NUM_ACTIONS) for _ in range(self.NUM_AGENTS)]
- )
- self.observation_space = Tuple(
- [OneHot(self.NUM_STATES) for _ in range(self.NUM_AGENTS)]
- )
- self.available_actions = [
- np.ones((batch_size, self.NUM_ACTIONS), dtype=int)
- for _ in range(self.NUM_AGENTS)
- ]
-
- self.step_count = None
-
- def reset(self):
- self.step_count = 0
- init_state = np.zeros(self.batch_size)
- observation = [init_state, init_state]
- info = [{'available_actions': aa} for aa in self.available_actions]
- return observation, info
-
- def step(self, action):
- ac0, ac1 = action
- self.step_count += 1
-
- r0 = self.payout_mat[ac0, ac1]
- r1 = self.payout_mat[ac1, ac0]
- s0 = self.states[ac0, ac1]
- s1 = self.states[ac1, ac0]
- observation = [s0, s1]
- reward = [r0, r1]
- done = (self.step_count == self.max_steps)
- info = [{'available_actions': aa} for aa in self.available_actions]
- return observation, reward, done, info
+
+ # Possible actions
+ NUM_AGENTS = 2
+ NUM_ACTIONS = 2
+ NUM_STATES = 5
+
+ def __init__(self, max_steps, batch_size=1):
+ self.max_steps = max_steps
+ self.batch_size = batch_size
+ self.payout_mat = np.array([[-2, 0], [-3, -1]])
+ self.states = np.array([[1, 2], [3, 4]])
+
+ self.action_space = Tuple([Discrete(self.NUM_ACTIONS) for _ in range(self.NUM_AGENTS)])
+ self.observation_space = Tuple([OneHot(self.NUM_STATES) for _ in range(self.NUM_AGENTS)])
+ self.available_actions = [
+ np.ones((batch_size, self.NUM_ACTIONS), dtype=int) for _ in range(self.NUM_AGENTS)
+ ]
+
+ self.step_count = None
+
+ def reset(self):
+ self.step_count = 0
+ init_state = np.zeros(self.batch_size)
+ observation = [init_state, init_state]
+ info = [{'available_actions': aa} for aa in self.available_actions]
+ return observation, info
+
+ def step(self, action):
+ ac0, ac1 = action
+ self.step_count += 1
+
+ r0 = self.payout_mat[ac0, ac1]
+ r1 = self.payout_mat[ac1, ac0]
+ s0 = self.states[ac0, ac1]
+ s1 = self.states[ac1, ac0]
+ observation = [s0, s1]
+ reward = [r0, r1]
+ done = (self.step_count == self.max_steps)
+ info = [{'available_actions': aa} for aa in self.available_actions]
+ return observation, reward, done, info
diff --git a/examples/LOLA/helper/utils.py b/examples/LOLA/helper/utils.py
index 30b8cf51..6b487a40 100755
--- a/examples/LOLA/helper/utils.py
+++ b/examples/LOLA/helper/utils.py
@@ -23,101 +23,97 @@
# evaluate the policy
def step(ipd, theta1, theta2, values1, values2, args):
- # just to evaluate progress:
- (s1, s2), _ = ipd.reset()
- score1 = 0
- score2 = 0
- for t in range(args.len_rollout):
- a1, lp1, v1 = act(s1, theta1, values1)
- a2, lp2, v2 = act(s2, theta2, values2)
- (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
- # cumulate scores
- score1 += np.mean(r1) / float(args.len_rollout)
- score2 += np.mean(r2) / float(args.len_rollout)
- return (score1, score2)
+ # just to evaluate progress:
+ (s1, s2), _ = ipd.reset()
+ score1 = 0
+ score2 = 0
+ for t in range(args.len_rollout):
+ a1, lp1, v1 = act(s1, theta1, values1)
+ a2, lp2, v2 = act(s2, theta2, values2)
+ (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
+ # cumulate scores
+ score1 += np.mean(r1) / float(args.len_rollout)
+ score2 += np.mean(r2) / float(args.len_rollout)
+ return (score1, score2)
# dice operator
def magic_box(x):
- return torch.exp(x - x.detach())
+ return torch.exp(x - x.detach())
# replay buffer
class Memory():
- def __init__(self, args):
- self.self_logprobs = []
- self.other_logprobs = []
- self.values = []
- self.rewards = []
- self.args = args
-
- def add(self, lp, other_lp, v, r):
- self.self_logprobs.append(lp)
- self.other_logprobs.append(other_lp)
- self.values.append(v)
- self.rewards.append(r)
-
- def dice_objective(self, use_baseline=True):
- self_logprobs = torch.stack(self.self_logprobs, dim=1)
- other_logprobs = torch.stack(self.other_logprobs, dim=1)
- values = torch.stack(self.values, dim=1)
- rewards = torch.stack(self.rewards, dim=1)
-
- # apply discount:
- cum_discount = torch.cumprod(
- self.args.gamma * torch.ones(*rewards.size()), dim=1
- ) / self.args.gamma
- discounted_rewards = rewards * cum_discount
- discounted_values = values * cum_discount
-
- # stochastics nodes involved in rewards dependencies:
- dependencies = torch.cumsum(self_logprobs + other_logprobs, dim=1)
-
- # logprob of each stochastic nodes:
- stochastic_nodes = self_logprobs + other_logprobs
-
- # dice objective:
- dice_objective = torch.mean(
- torch.sum(magic_box(dependencies) * discounted_rewards, dim=1)
- )
-
- if use_baseline:
- # variance_reduction:
- baseline_term = torch.mean(
- torch.sum(
- (1 - magic_box(stochastic_nodes)) * discounted_values, dim=1
- )
- )
- dice_objective = dice_objective + baseline_term
-
- return -dice_objective # want to minimize -objective
-
- def value_loss(self):
- values = torch.stack(self.values, dim=1)
- rewards = torch.stack(self.rewards, dim=1)
- return torch.mean((rewards - values)**2)
+ def __init__(self, args):
+ self.self_logprobs = []
+ self.other_logprobs = []
+ self.values = []
+ self.rewards = []
+ self.args = args
+
+ def add(self, lp, other_lp, v, r):
+ self.self_logprobs.append(lp)
+ self.other_logprobs.append(other_lp)
+ self.values.append(v)
+ self.rewards.append(r)
+
+ def dice_objective(self, use_baseline=True):
+ self_logprobs = torch.stack(self.self_logprobs, dim=1)
+ other_logprobs = torch.stack(self.other_logprobs, dim=1)
+ values = torch.stack(self.values, dim=1)
+ rewards = torch.stack(self.rewards, dim=1)
+
+ # apply discount:
+ cum_discount = torch.cumprod(
+ self.args.gamma * torch.ones(*rewards.size()), dim=1
+ ) / self.args.gamma
+ discounted_rewards = rewards * cum_discount
+ discounted_values = values * cum_discount
+
+ # stochastics nodes involved in rewards dependencies:
+ dependencies = torch.cumsum(self_logprobs + other_logprobs, dim=1)
+
+ # logprob of each stochastic nodes:
+ stochastic_nodes = self_logprobs + other_logprobs
+
+ # dice objective:
+ dice_objective = torch.mean(torch.sum(magic_box(dependencies) * discounted_rewards, dim=1))
+
+ if use_baseline:
+ # variance_reduction:
+ baseline_term = torch.mean(
+ torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1)
+ )
+ dice_objective = dice_objective + baseline_term
+
+ return -dice_objective # want to minimize -objective
+
+ def value_loss(self):
+ values = torch.stack(self.values, dim=1)
+ rewards = torch.stack(self.rewards, dim=1)
+ return torch.mean((rewards - values)**2)
def act(batch_states, theta, values):
- batch_states = torch.from_numpy(batch_states).long()
- probs = torch.sigmoid(theta)[batch_states]
- m = Bernoulli(1 - probs)
- actions = m.sample()
- log_probs_actions = m.log_prob(actions)
- return actions.numpy().astype(int), log_probs_actions, values[batch_states]
+ batch_states = torch.from_numpy(batch_states).long()
+ probs = torch.sigmoid(theta)[batch_states]
+ m = Bernoulli(1 - probs)
+ actions = m.sample()
+ log_probs_actions = m.log_prob(actions)
+ return actions.numpy().astype(int), log_probs_actions, values[batch_states]
def sample(ipd, policy, value, args):
- theta1, theta2 = policy
- value1, value2 = value
- (s1, s2), _ = ipd.reset()
- memory_agent1 = Memory(args)
- memory_agent2 = Memory(args)
- for t in range(args.len_rollout):
- a1, lp1, v1 = act(s1, theta1, value1)
- a2, lp2, v2 = act(s2, theta2, value2)
- (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
- memory_agent1.add(lp1, lp2, v1, torch.from_numpy(r1).float())
- memory_agent2.add(lp2, lp1, v2, torch.from_numpy(r2).float())
- return memory_agent1, memory_agent2
+ theta1, theta2 = policy
+ value1, value2 = value
+ (s1, s2), _ = ipd.reset()
+ memory_agent1 = Memory(args)
+ memory_agent2 = Memory(args)
+ for t in range(args.len_rollout):
+ a1, lp1, v1 = act(s1, theta1, value1)
+ a2, lp2, v2 = act(s2, theta2, value2)
+ (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
+ memory_agent1.add(lp1, lp2, v1, torch.from_numpy(r1).float())
+ memory_agent2.add(lp2, lp1, v2, torch.from_numpy(r2).float())
+ return memory_agent1, memory_agent2
diff --git a/examples/LOLA/lola_dice.py b/examples/LOLA/lola_dice.py
index f5a112da..82d585d6 100755
--- a/examples/LOLA/lola_dice.py
+++ b/examples/LOLA/lola_dice.py
@@ -16,105 +16,94 @@
# https://github.com/alexis-jacq/LOLA_DiCE
# ==============================================================================
-from copy import deepcopy
-
-import matplotlib.pyplot as plt
import numpy as np
import torch
-import torch.nn as nn
-from helper.agent import Agent
-from helper.argument import parse_args
-from helper.env import IPD
-from helper.utils import sample, step
-from torch.distributions import Bernoulli
-import TorchOpt
+from .helper.agent import Agent
+from .helper.argument import parse_args
+from .helper.env import IPD
+from .helper.utils import sample, step
def main(args):
- ipd = IPD(args.len_rollout, args.batch_size)
- agent1, agent2 = Agent(args), Agent(args)
- agent1_copy, agent2_copy = Agent(args), Agent(args)
- n_lookaheads = args.n_lookaheads
- joint_scores = []
- print("start iterations with", n_lookaheads, "lookaheads:")
-
- for update in range(args.n_update):
- # reset virtual update
- agent1.set_virtual()
- agent2.set_virtual()
-
- # agent 2 assumes that agent 1 conducts n-step lookahead
- for _ in range(n_lookaheads):
- memory1, memory2 = sample(
- ipd, [agent1.virtual_theta.theta, agent2.theta],
- [agent1.values, agent2.values], args
- )
- inner_loss = memory1.dice_objective(use_baseline=args.use_baseline)
- agent1.virtual_optimiser.step(inner_loss)
-
- # agent 1 assumes that agent 2 conducts n-step lookahead
- for _ in range(n_lookaheads):
- memory1, memory2 = sample(
- ipd, [agent1.theta, agent2.virtual_theta.theta],
- [agent1.values, agent2.values], args
- )
- inner_loss = memory2.dice_objective(use_baseline=args.use_baseline)
- agent2.virtual_optimiser.step(inner_loss)
-
- # update agent 1
- memory1, memory2 = sample(
- ipd, [agent1.theta, agent2.virtual_theta.theta],
- [agent1.values, agent2.values], args
- )
- outer_loss = memory1.dice_objective(use_baseline=args.use_baseline)
- agent1.theta_optimizer.zero_grad()
- outer_loss.backward(retain_graph=True)
- agent1.theta_optimizer.step()
-
- # update agent 1 value function
- v_loss = memory1.value_loss()
- agent1.value_update(v_loss)
-
- # update agent 2
- memory1, memory2 = sample(
- ipd, [agent1.virtual_theta.theta, agent2.theta],
- [agent1.values, agent2.values], args
- )
- outer_loss = memory2.dice_objective(use_baseline=args.use_baseline)
- agent2.theta_optimizer.zero_grad()
- outer_loss.backward(retain_graph=True)
- agent2.theta_optimizer.step()
-
- # update agent 2 value function
- v_loss = memory2.value_loss()
- agent2.value_update(v_loss)
-
- # evaluate progress:
- score = step(
- ipd, agent1.theta, agent2.theta, agent1.values, agent2.values, args
- )
- joint_scores.append(0.5 * (score[0] + score[1]))
-
- # print
- if update % 10 == 0:
- p1 = [p.item() for p in torch.sigmoid(agent1.theta)]
- p2 = [p.item() for p in torch.sigmoid(agent2.theta)]
- print(
- 'update', update, 'score (%.3f,%.3f)' % (score[0], score[1]),
- 'policy (agent1) = {%.3f, %.3f, %.3f, %.3f, %.3f}' %
- (p1[0], p1[1], p1[2], p1[3], p1[4]),
- ' (agent2) = {%.3f, %.3f, %.3f, %.3f, %.3f}' %
- (p2[0], p2[1], p2[2], p2[3], p2[4])
- )
-
- return joint_scores
+ ipd = IPD(args.len_rollout, args.batch_size)
+ agent1, agent2 = Agent(args), Agent(args)
+ agent1_copy, agent2_copy = Agent(args), Agent(args)
+ n_lookaheads = args.n_lookaheads
+ joint_scores = []
+ print("start iterations with", n_lookaheads, "lookaheads:")
+
+ for update in range(args.n_update):
+ # reset virtual update
+ agent1.set_virtual()
+ agent2.set_virtual()
+
+ # agent 2 assumes that agent 1 conducts n-step lookahead
+ for _ in range(n_lookaheads):
+ memory1, memory2 = sample(
+ ipd, [agent1.virtual_theta.theta, agent2.theta], [agent1.values, agent2.values],
+ args
+ )
+ inner_loss = memory1.dice_objective(use_baseline=args.use_baseline)
+ agent1.virtual_optimiser.step(inner_loss)
+
+ # agent 1 assumes that agent 2 conducts n-step lookahead
+ for _ in range(n_lookaheads):
+ memory1, memory2 = sample(
+ ipd, [agent1.theta, agent2.virtual_theta.theta], [agent1.values, agent2.values],
+ args
+ )
+ inner_loss = memory2.dice_objective(use_baseline=args.use_baseline)
+ agent2.virtual_optimiser.step(inner_loss)
+
+ # update agent 1
+ memory1, memory2 = sample(
+ ipd, [agent1.theta, agent2.virtual_theta.theta], [agent1.values, agent2.values], args
+ )
+ outer_loss = memory1.dice_objective(use_baseline=args.use_baseline)
+ agent1.theta_optimizer.zero_grad()
+ outer_loss.backward(retain_graph=True)
+ agent1.theta_optimizer.step()
+
+ # update agent 1 value function
+ v_loss = memory1.value_loss()
+ agent1.value_update(v_loss)
+
+ # update agent 2
+ memory1, memory2 = sample(
+ ipd, [agent1.virtual_theta.theta, agent2.theta], [agent1.values, agent2.values], args
+ )
+ outer_loss = memory2.dice_objective(use_baseline=args.use_baseline)
+ agent2.theta_optimizer.zero_grad()
+ outer_loss.backward(retain_graph=True)
+ agent2.theta_optimizer.step()
+
+ # update agent 2 value function
+ v_loss = memory2.value_loss()
+ agent2.value_update(v_loss)
+
+ # evaluate progress:
+ score = step(ipd, agent1.theta, agent2.theta, agent1.values, agent2.values, args)
+ joint_scores.append(0.5 * (score[0] + score[1]))
+
+ # print
+ if update % 10 == 0:
+ p1 = [p.item() for p in torch.sigmoid(agent1.theta)]
+ p2 = [p.item() for p in torch.sigmoid(agent2.theta)]
+ print(
+ 'update', update, 'score (%.3f,%.3f)' % (score[0], score[1]),
+ 'policy (agent1) = {%.3f, %.3f, %.3f, %.3f, %.3f}' %
+ (p1[0], p1[1], p1[2], p1[3], p1[4]),
+ ' (agent2) = {%.3f, %.3f, %.3f, %.3f, %.3f}' % (p2[0], p2[1], p2[2], p2[3], p2[4])
+ )
+
+ return joint_scores
if __name__ == "__main__":
- args = parse_args()
- joint_score = dict()
- for nla in range(3):
- args.n_lookaheads = nla
- joint_score[nla] = main(args)
- np.save('result.npy', joint_score)
+ args = parse_args()
+ joint_score = dict()
+ for nla in range(3):
+ args.n_lookaheads = nla
+ joint_score[nla] = main(args)
+ np.save('result.npy', joint_score)
diff --git a/examples/LOLA/visualise.py b/examples/LOLA/visualise.py
index de71afef..2640f6a7 100755
--- a/examples/LOLA/visualise.py
+++ b/examples/LOLA/visualise.py
@@ -19,17 +19,17 @@
def plot(file):
- data = np.load('result.npy', allow_pickle=True).tolist()
- sns.set(style='darkgrid')
- sns.set_theme(style="darkgrid")
- for step in range(3):
- plt.plot(data[step], label='Step ' + str(step))
- plt.legend()
- plt.xlabel('Iteartions', fontsize=20)
- plt.ylabel('Joint score', fontsize=20)
- plt.savefig('./result.png')
+ data = np.load('result.npy', allow_pickle=True).tolist()
+ sns.set(style='darkgrid')
+ sns.set_theme(style="darkgrid")
+ for step in range(3):
+ plt.plot(data[step], label='Step ' + str(step))
+ plt.legend()
+ plt.xlabel('Iteartions', fontsize=20)
+ plt.ylabel('Joint score', fontsize=20)
+ plt.savefig('./result.png')
# plot progress:
if __name__ == "__main__":
- plot('result.npy')
+ plot('result.npy')
diff --git a/examples/MAML-RL/README.md b/examples/MAML-RL/README.md
index 26a80200..d99738e3 100755
--- a/examples/MAML-RL/README.md
+++ b/examples/MAML-RL/README.md
@@ -1,16 +1,20 @@
# Reinforcement learning with Model-Agnostic Meta-Learning (MAML)
-Code on Tabular MDP example in paper *Model-Agnostic Meta-Learning* [[MAML](https://arxiv.org/abs/1703.03400)] using `TorchOpt`. The idea of MAML is to learn the initial parameters of an agent's policy so that the agent can rapidly adapt to new environments with a limited number of policy-gradient updates. We use `MetaSGD` as the inner-loop optimiser.
+Code on Tabular MDP example in paper *Model-Agnostic Meta-Learning* [[MAML](https://arxiv.org/abs/1703.03400)] using TorchOpt. The idea of MAML is to learn the initial parameters of an agent's policy so that the agent can rapidly adapt to new environments with a limited number of policy-gradient updates. We use `MetaSGD` as the inner-loop optimiser.
+
+## Usage
-# Usage
Specify the seed to train.
+
```bash
### Run MAML
python run_MAML.py --seed 1
```
-# Results
+## Results
+
The training curve and testing curve between initial policy and adapted policy validate the effectiveness of algorithms.
+
-

+
diff --git a/examples/MAML-RL/helpers/Tabular_mdp.py b/examples/MAML-RL/helpers/Tabular_mdp.py
index 32a9d929..1df07599 100644
--- a/examples/MAML-RL/helpers/Tabular_mdp.py
+++ b/examples/MAML-RL/helpers/Tabular_mdp.py
@@ -20,18 +20,17 @@
import numpy as np
from gym import spaces
from gym.utils import seeding
-from gym.wrappers.time_limit import TimeLimit
class TabularMDPEnv(gym.Env):
- """Tabular MDP problems, as described in [1].
-
- At each time step, the agent chooses one of `num_actions` actions, say `i`,
- receives a reward sampled from a Normal distribution with mean `m_i` and
- variance 1 (fixed across all tasks), and reaches a new state following the
- dynamics of the Markov Decision Process (MDP). The tabular MDP tasks are
- generated by sampling the mean rewards from a Normal distribution with mean
- 1 and variance 1, and sampling the transition probabilities from a uniform
+ """Tabular MDP problems, as described in [1].
+
+ At each time step, the agent chooses one of `num_actions` actions, say `i`,
+ receives a reward sampled from a Normal distribution with mean `m_i` and
+ variance 1 (fixed across all tasks), and reaches a new state following the
+ dynamics of the Markov Decision Process (MDP). The tabular MDP tasks are
+ generated by sampling the mean rewards from a Normal distribution with mean
+ 1 and variance 1, and sampling the transition probabilities from a uniform
Dirichlet distribution (ie. with parameter 1).
[1] Yan Duan, John Schulman, Xi Chen, Peter L. Bartlett, Ilya Sutskever,
@@ -39,83 +38,76 @@ class TabularMDPEnv(gym.Env):
Learning", 2016 (https://arxiv.org/abs/1611.02779)
"""
- def __init__(
- self, num_states, num_actions, max_episode_steps, seed, task={}
- ):
- super(TabularMDPEnv, self).__init__()
- self.max_episode_steps = max_episode_steps
- self.num_states = num_states
- self.num_actions = num_actions
-
- self.action_space = spaces.Discrete(num_actions)
- self.observation_space = spaces.Box(
- low=0.0, high=1.0, shape=(num_states,), dtype=np.float32
- )
-
- self._task = task
- self._transitions = task.get(
- 'transitions',
- np.full(
- (num_states, num_actions, num_states),
- 1.0 / num_states,
- dtype=np.float32
- )
- )
- self._rewards_mean = task.get(
- 'rewards_mean', np.zeros((num_states, num_actions), dtype=np.float32)
- )
- self._state = 0
- self._elapsed_steps = None
-
- self.seed(seed)
-
- def seed(self, seed=None):
- self.np_random, seed = seeding.np_random(seed)
- return [seed]
-
- def sample_tasks(self, num_tasks):
- transitions = self.np_random.dirichlet(
- np.ones(self.num_states),
- size=(num_tasks, self.num_states, self.num_actions)
- )
- rewards_mean = self.np_random.normal(
- 1.0, 1.0, size=(num_tasks, self.num_states, self.num_actions)
- )
- tasks = [
- {
- 'transitions': transition,
- 'rewards_mean': reward_mean
- } for (transition, reward_mean) in zip(transitions, rewards_mean)
- ]
- return tasks
-
- def reset_task(self, task):
- self._task = task
- self._transitions = task['transitions']
- self._rewards_mean = task['rewards_mean']
-
- def reset(self):
- # From [1]: "an episode always starts on the first state"
- self._state = 0
- observation = np.zeros(self.num_states, dtype=np.float32)
- observation[self._state] = 1.0
- self._elapsed_steps = 0
-
- return observation
-
- def step(self, action):
- assert self.action_space.contains(action)
- mean = self._rewards_mean[self._state, action]
- reward = self.np_random.normal(mean, 1.0)
-
- self._state = self.np_random.choice(
- self.num_states, p=self._transitions[self._state, action]
- )
- observation = np.zeros(self.num_states, dtype=np.float32)
- observation[self._state] = 1.0
- self._elapsed_steps += 1
- if self._elapsed_steps >= self.max_episode_steps:
- done = True
- else:
- done = False
- return observation, reward, done, {'task': self._task}
+ def __init__(self, num_states, num_actions, max_episode_steps, seed, task={}):
+ super(TabularMDPEnv, self).__init__()
+ self.max_episode_steps = max_episode_steps
+ self.num_states = num_states
+ self.num_actions = num_actions
+
+ self.action_space = spaces.Discrete(num_actions)
+ self.observation_space = spaces.Box(
+ low=0.0, high=1.0, shape=(num_states,), dtype=np.float32
+ )
+
+ self._task = task
+ self._transitions = task.get(
+ 'transitions',
+ np.full((num_states, num_actions, num_states), 1.0 / num_states, dtype=np.float32)
+ )
+ self._rewards_mean = task.get(
+ 'rewards_mean', np.zeros((num_states, num_actions), dtype=np.float32)
+ )
+ self._state = 0
+ self._elapsed_steps = None
+
+ self.seed(seed)
+
+ def seed(self, seed=None):
+ self.np_random, seed = seeding.np_random(seed)
+ return [seed]
+
+ def sample_tasks(self, num_tasks):
+ transitions = self.np_random.dirichlet(
+ np.ones(self.num_states), size=(num_tasks, self.num_states, self.num_actions)
+ )
+ rewards_mean = self.np_random.normal(
+ 1.0, 1.0, size=(num_tasks, self.num_states, self.num_actions)
+ )
+ tasks = [
+ {
+ 'transitions': transition,
+ 'rewards_mean': reward_mean
+ } for (transition, reward_mean) in zip(transitions, rewards_mean)
+ ]
+ return tasks
+
+ def reset_task(self, task):
+ self._task = task
+ self._transitions = task['transitions']
+ self._rewards_mean = task['rewards_mean']
+
+ def reset(self):
+ # From [1]: "an episode always starts on the first state"
+ self._state = 0
+ observation = np.zeros(self.num_states, dtype=np.float32)
+ observation[self._state] = 1.0
+ self._elapsed_steps = 0
+
+ return observation
+
+ def step(self, action):
+ assert self.action_space.contains(action)
+ mean = self._rewards_mean[self._state, action]
+ reward = self.np_random.normal(mean, 1.0)
+
+ self._state = self.np_random.choice(
+ self.num_states, p=self._transitions[self._state, action]
+ )
+ observation = np.zeros(self.num_states, dtype=np.float32)
+ observation[self._state] = 1.0
+ self._elapsed_steps += 1
+ if self._elapsed_steps >= self.max_episode_steps:
+ done = True
+ else:
+ done = False
+ return observation, reward, done, {'task': self._task}
diff --git a/examples/MAML-RL/helpers/__init__.py b/examples/MAML-RL/helpers/__init__.py
index c3fee90d..e8761adc 100644
--- a/examples/MAML-RL/helpers/__init__.py
+++ b/examples/MAML-RL/helpers/__init__.py
@@ -19,12 +19,12 @@
from gym.envs.registration import register
register(
- 'TabularMDP-v0',
- entry_point='helpers.Tabular_mdp:TabularMDPEnv',
- kwargs={
- 'num_states': 10,
- 'num_actions': 5,
- 'max_episode_steps': 10,
- 'seed': 1
- }
+ 'TabularMDP-v0',
+ entry_point='helpers.Tabular_mdp:TabularMDPEnv',
+ kwargs={
+ 'num_states': 10,
+ 'num_actions': 5,
+ 'max_episode_steps': 10,
+ 'seed': 1
+ }
)
diff --git a/examples/MAML-RL/helpers/policy.py b/examples/MAML-RL/helpers/policy.py
index 54ee3f5c..66ab1fa3 100644
--- a/examples/MAML-RL/helpers/policy.py
+++ b/examples/MAML-RL/helpers/policy.py
@@ -22,28 +22,28 @@
class CategoricalMLPPolicy(nn.Module):
- """Policy network based on a multi-layer perceptron (MLP), with a
- `Categorical` distribution output. This policy network can be used on tasks
- with discrete action spaces (eg. `TabularMDPEnv`).
+ """Policy network based on a multi-layer perceptron (MLP), with a
+ `Categorical` distribution output. This policy network can be used on tasks
+ with discrete action spaces (eg. `TabularMDPEnv`).
"""
- def __init__(
- self,
- input_size,
- output_size,
- ):
- super(CategoricalMLPPolicy, self).__init__()
- self.torso = nn.Sequential(
- nn.Linear(input_size, 32),
- nn.ReLU(),
- nn.Linear(32, 32),
- nn.ReLU(),
- )
- self.policy_head = nn.Linear(32, output_size)
- self.value_head = nn.Linear(32, 1)
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ ):
+ super(CategoricalMLPPolicy, self).__init__()
+ self.torso = nn.Sequential(
+ nn.Linear(input_size, 32),
+ nn.ReLU(),
+ nn.Linear(32, 32),
+ nn.ReLU(),
+ )
+ self.policy_head = nn.Linear(32, output_size)
+ self.value_head = nn.Linear(32, 1)
- def forward(self, inputs, params=None):
- embedding = self.torso(inputs)
- logits = self.policy_head(embedding)
- values = self.value_head(embedding)
- return Categorical(logits=logits), values
+ def forward(self, inputs, params=None):
+ embedding = self.torso(inputs)
+ logits = self.policy_head(embedding)
+ values = self.value_head(embedding)
+ return Categorical(logits=logits), values
diff --git a/examples/MAML-RL/run_MAML.py b/examples/MAML-RL/run_MAML.py
index 1507e8bc..252f25e0 100644
--- a/examples/MAML-RL/run_MAML.py
+++ b/examples/MAML-RL/run_MAML.py
@@ -20,9 +20,10 @@
import numpy as np
import torch
import torch.optim as optim
-from helpers.policy import CategoricalMLPPolicy
-import TorchOpt
+import torchopt
+
+from .helpers.policy import CategoricalMLPPolicy
TASK_NUM = 40
TRAJ_NUM = 20
@@ -39,173 +40,161 @@
class Traj(NamedTuple):
- obs: np.ndarray
- acs: np.ndarray
- next_obs: np.ndarray
- rews: np.ndarray
- gammas: np.ndarray
+ obs: np.ndarray
+ acs: np.ndarray
+ next_obs: np.ndarray
+ rews: np.ndarray
+ gammas: np.ndarray
def sample_traj(env, task, policy):
- env.reset_task(task)
- obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
- next_obs_buf = np.zeros(
- shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32
- )
- acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8)
- rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
- gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
- with torch.no_grad():
- for batch in range(TRAJ_NUM):
- ob = env.reset()
- for step in range(TRAJ_LEN):
- ob_tensor = torch.from_numpy(ob)
- pi, _ = policy(ob_tensor)
- ac_tensor = pi.sample()
- ac = ac_tensor.cpu().numpy()
- next_ob, rew, done, info = env.step(ac)
-
- obs_buf[step][batch] = ob
- next_obs_buf[step][batch] = next_ob
- acs_buf[step][batch] = ac
- rews_buf[step][batch] = rew
- gammas_buf[step][batch] = done * GAMMA
- ob = next_ob
- return Traj(
- obs=obs_buf,
- acs=acs_buf,
- next_obs=next_obs_buf,
- rews=rews_buf,
- gammas=gammas_buf
- )
+ env.reset_task(task)
+ obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
+ next_obs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM, STATE_DIM), dtype=np.float32)
+ acs_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.int8)
+ rews_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
+ gammas_buf = np.zeros(shape=(TRAJ_LEN, TRAJ_NUM), dtype=np.float32)
+ with torch.no_grad():
+ for batch in range(TRAJ_NUM):
+ ob = env.reset()
+ for step in range(TRAJ_LEN):
+ ob_tensor = torch.from_numpy(ob)
+ pi, _ = policy(ob_tensor)
+ ac_tensor = pi.sample()
+ ac = ac_tensor.cpu().numpy()
+ next_ob, rew, done, info = env.step(ac)
+
+ obs_buf[step][batch] = ob
+ next_obs_buf[step][batch] = next_ob
+ acs_buf[step][batch] = ac
+ rews_buf[step][batch] = rew
+ gammas_buf[step][batch] = done * GAMMA
+ ob = next_ob
+ return Traj(obs=obs_buf, acs=acs_buf, next_obs=next_obs_buf, rews=rews_buf, gammas=gammas_buf)
def a2c_loss(traj, policy, value_coef):
- lambdas = np.ones_like(traj.gammas) * LAMBDA
- _, next_values = policy(torch.from_numpy(traj.next_obs))
- next_values = torch.squeeze(next_values, -1).detach().numpy()
- # Work backwards to compute `G_{T-1}`, ..., `G_0`.
- returns = []
- g = next_values[-1, :]
- for i in reversed(range(next_values.shape[0])):
- g = traj.rews[i, :] + traj.gammas[i, :] * \
- ((1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g)
- returns.insert(0, g)
- lambda_returns = torch.from_numpy(np.array(returns))
- pi, values = policy(torch.from_numpy(traj.obs))
- log_probs = pi.log_prob(torch.from_numpy(traj.acs))
- advs = lambda_returns - torch.squeeze(values, -1)
- action_loss = -(advs.detach() * log_probs).mean()
- value_loss = advs.pow(2).mean()
-
- a2c_loss = action_loss + value_coef * value_loss
- return a2c_loss
+ lambdas = np.ones_like(traj.gammas) * LAMBDA
+ _, next_values = policy(torch.from_numpy(traj.next_obs))
+ next_values = torch.squeeze(next_values, -1).detach().numpy()
+ # Work backwards to compute `G_{T-1}`, ..., `G_0`.
+ returns = []
+ g = next_values[-1, :]
+ for i in reversed(range(next_values.shape[0])):
+ g = traj.rews[i, :] + traj.gammas[i, :] * \
+ ((1 - lambdas[i, :]) * next_values[i, :] + lambdas[i, :] * g)
+ returns.insert(0, g)
+ lambda_returns = torch.from_numpy(np.array(returns))
+ pi, values = policy(torch.from_numpy(traj.obs))
+ log_probs = pi.log_prob(torch.from_numpy(traj.acs))
+ advs = lambda_returns - torch.squeeze(values, -1)
+ action_loss = -(advs.detach() * log_probs).mean()
+ value_loss = advs.pow(2).mean()
+
+ a2c_loss = action_loss + value_coef * value_loss
+ return a2c_loss
def evaluate(env, seed, task_num, policy):
- pre_reward_ls = []
- post_reward_ls = []
- inner_opt = TorchOpt.MetaSGD(policy, lr=0.5)
- env = gym.make(
- 'TabularMDP-v0',
- **dict(
- num_states=STATE_DIM,
- num_actions=ACTION_DIM,
- max_episode_steps=TRAJ_LEN,
- seed=args.seed
+ pre_reward_ls = []
+ post_reward_ls = []
+ inner_opt = torchopt.MetaSGD(policy, lr=0.5)
+ env = gym.make(
+ 'TabularMDP-v0',
+ **dict(
+ num_states=STATE_DIM,
+ num_actions=ACTION_DIM,
+ max_episode_steps=TRAJ_LEN,
+ seed=args.seed
+ )
)
- )
- tasks = env.sample_tasks(num_tasks=task_num)
- policy_state_dict = TorchOpt.extract_state_dict(policy)
- optim_state_dict = TorchOpt.extract_state_dict(inner_opt)
- for idx in range(task_num):
- for _ in range(inner_iters):
- pre_trajs = sample_traj(env, tasks[idx], policy)
+ tasks = env.sample_tasks(num_tasks=task_num)
+ policy_state_dict = torchopt.extract_state_dict(policy)
+ optim_state_dict = torchopt.extract_state_dict(inner_opt)
+ for idx in range(task_num):
+ for _ in range(inner_iters):
+ pre_trajs = sample_traj(env, tasks[idx], policy)
- inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
- inner_opt.step(inner_loss)
- post_trajs = sample_traj(env, tasks[idx], policy)
+ inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
+ inner_opt.step(inner_loss)
+ post_trajs = sample_traj(env, tasks[idx], policy)
- # Logging
- pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
- post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
+ # Logging
+ pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
+ post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
- TorchOpt.recover_state_dict(policy, policy_state_dict)
- TorchOpt.recover_state_dict(inner_opt, optim_state_dict)
- return pre_reward_ls, post_reward_ls
+ torchopt.recover_state_dict(policy, policy_state_dict)
+ torchopt.recover_state_dict(inner_opt, optim_state_dict)
+ return pre_reward_ls, post_reward_ls
def main(args):
- # init training
- torch.manual_seed(args.seed)
- torch.cuda.manual_seed_all(args.seed)
- # Env
- env = gym.make(
- 'TabularMDP-v0',
- **dict(
- num_states=STATE_DIM,
- num_actions=ACTION_DIM,
- max_episode_steps=TRAJ_LEN,
- seed=args.seed
- )
- )
- # Policy
- policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
- inner_opt = TorchOpt.MetaSGD(policy, lr=0.5)
- outer_opt = optim.Adam(policy.parameters(), lr=1e-3)
- train_pre_reward = []
- train_post_reward = []
- test_pre_reward = []
- test_post_reward = []
-
- for i in range(outer_iters):
- tasks = env.sample_tasks(num_tasks=TASK_NUM)
- train_pre_reward_ls = []
- train_post_reward_ls = []
-
- outer_opt.zero_grad()
-
- policy_state_dict = TorchOpt.extract_state_dict(policy)
- optim_state_dict = TorchOpt.extract_state_dict(inner_opt)
- for idx in range(TASK_NUM):
-
- for _ in range(inner_iters):
- pre_trajs = sample_traj(env, tasks[idx], policy)
- inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
- inner_opt.step(inner_loss)
- post_trajs = sample_traj(env, tasks[idx], policy)
- outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
- outer_loss.backward()
- TorchOpt.recover_state_dict(policy, policy_state_dict)
- TorchOpt.recover_state_dict(inner_opt, optim_state_dict)
- # Logging
- train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
- train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
- outer_opt.step()
-
- test_pre_reward_ls, test_post_reward_ls = evaluate(
- env, args.seed, TASK_NUM, policy
+ # init training
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ # Env
+ env = gym.make(
+ 'TabularMDP-v0',
+ **dict(
+ num_states=STATE_DIM,
+ num_actions=ACTION_DIM,
+ max_episode_steps=TRAJ_LEN,
+ seed=args.seed
+ )
)
-
- train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
- train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM)
- test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM)
- test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM)
-
- print('Train_iters', i)
- print("train_pre_reward", sum(train_pre_reward_ls) / TASK_NUM)
- print("train_post_reward", sum(train_post_reward_ls) / TASK_NUM)
- print("test_pre_reward", sum(test_pre_reward_ls) / TASK_NUM)
- print("test_post_reward", sum(test_post_reward_ls) / TASK_NUM)
+ # Policy
+ policy = CategoricalMLPPolicy(input_size=STATE_DIM, output_size=ACTION_DIM)
+ inner_opt = torchopt.MetaSGD(policy, lr=0.5)
+ outer_opt = optim.Adam(policy.parameters(), lr=1e-3)
+ train_pre_reward = []
+ train_post_reward = []
+ test_pre_reward = []
+ test_post_reward = []
+
+ for i in range(outer_iters):
+ tasks = env.sample_tasks(num_tasks=TASK_NUM)
+ train_pre_reward_ls = []
+ train_post_reward_ls = []
+
+ outer_opt.zero_grad()
+
+ policy_state_dict = torchopt.extract_state_dict(policy)
+ optim_state_dict = torchopt.extract_state_dict(inner_opt)
+ for idx in range(TASK_NUM):
+
+ for _ in range(inner_iters):
+ pre_trajs = sample_traj(env, tasks[idx], policy)
+ inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
+ inner_opt.step(inner_loss)
+ post_trajs = sample_traj(env, tasks[idx], policy)
+ outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
+ outer_loss.backward()
+ torchopt.recover_state_dict(policy, policy_state_dict)
+ torchopt.recover_state_dict(inner_opt, optim_state_dict)
+ # Logging
+ train_pre_reward_ls.append(np.sum(pre_trajs.rews, axis=0).mean())
+ train_post_reward_ls.append(np.sum(post_trajs.rews, axis=0).mean())
+ outer_opt.step()
+
+ test_pre_reward_ls, test_post_reward_ls = evaluate(env, args.seed, TASK_NUM, policy)
+
+ train_pre_reward.append(sum(train_pre_reward_ls) / TASK_NUM)
+ train_post_reward.append(sum(train_post_reward_ls) / TASK_NUM)
+ test_pre_reward.append(sum(test_pre_reward_ls) / TASK_NUM)
+ test_post_reward.append(sum(test_post_reward_ls) / TASK_NUM)
+
+ print('Train_iters', i)
+ print("train_pre_reward", sum(train_pre_reward_ls) / TASK_NUM)
+ print("train_post_reward", sum(train_post_reward_ls) / TASK_NUM)
+ print("test_pre_reward", sum(test_pre_reward_ls) / TASK_NUM)
+ print("test_post_reward", sum(test_post_reward_ls) / TASK_NUM)
if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description='Reinforcement learning with '
- 'Model-Agnostic Meta-Learning (MAML) - Train'
- )
- parser.add_argument(
- '--seed', type=int, default=1, help='random seed (default: 1)'
- )
- args = parser.parse_args()
- main(args)
+ parser = argparse.ArgumentParser(
+ description='Reinforcement learning with '
+ 'Model-Agnostic Meta-Learning (MAML) - Train'
+ )
+ parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/MGRL/README.md b/examples/MGRL/README.md
index 65299729..e2952d12 100644
--- a/examples/MGRL/README.md
+++ b/examples/MGRL/README.md
@@ -1,8 +1,10 @@
# MGRL-examples
-Code on toy example of meta-learning the discount factor in paper [Meta-Gradient Reinforcement Learning](https://arxiv.org/abs/1805.09801) using `TorchOpt`. We use `MetaSGD` as the inner-loop optimiser.
+Code on toy example of meta-learning the discount factor in paper [Meta-Gradient Reinforcement Learning](https://arxiv.org/abs/1805.09801) using TorchOpt. We use `MetaSGD` as the inner-loop optimiser.
+
+## Usage
-# Usage
```bash
-### Run
+### Run
python3 toy.py
+```
diff --git a/examples/MGRL/toy.py b/examples/MGRL/toy.py
index 5ce5ad1c..4f0feeb3 100644
--- a/examples/MGRL/toy.py
+++ b/examples/MGRL/toy.py
@@ -14,71 +14,71 @@
# ==============================================================================
import torch
-from torch import nn
-from torch.nn import functional as F
+import torch.nn as nn
+import torch.nn.functional as F
-import TorchOpt
+import torchopt
def test_gamma():
- class Rollout:
-
- @staticmethod
- def get():
- out = torch.empty(5, 2)
- out[:, 0] = torch.randn(5)
- out[:, 1] = 0.1 * torch.ones(5)
- label = torch.arange(0, 10)
- return out.view(10, 1), F.one_hot(label, 10)
-
- @staticmethod
- def rollout(trajectory, gamma):
- out = [trajectory[-1]]
- for i in reversed(range(9)):
- out.append(trajectory[i] + gamma[i] * out[-1].clone().detach_())
- out.reverse()
- return torch.hstack(out).view(10, 1)
-
- class ValueNetwork(nn.Module):
-
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(10, 1)
-
- def forward(self, x):
- return self.fc(x)
-
- torch.manual_seed(0)
- inner_iters = 1
- outer_iters = 10000
- net = ValueNetwork()
- inner_optimizer = TorchOpt.MetaSGD(net, lr=5e-1)
- gamma = torch.zeros(9, requires_grad=True)
- meta_optimizer = TorchOpt.SGD([gamma], lr=5e-1)
- net_state = TorchOpt.extract_state_dict(net)
- for i in range(outer_iters):
- for j in range(inner_iters):
- trajectory, state = Rollout.get()
- backup = Rollout.rollout(trajectory, torch.sigmoid(gamma))
- pred_value = net(state.float())
-
- loss = F.mse_loss(pred_value, backup)
- inner_optimizer.step(loss)
-
- trajectory, state = Rollout.get()
- pred_value = net(state.float())
- backup = Rollout.rollout(trajectory, torch.ones_like(gamma))
-
- loss = F.mse_loss(pred_value, backup)
- meta_optimizer.zero_grad()
- loss.backward()
- meta_optimizer.step()
- TorchOpt.recover_state_dict(net, net_state)
- if i % 100 == 0:
- with torch.no_grad():
- print(f"epoch {i} | gamma: {torch.sigmoid(gamma)}")
+ class Rollout:
+
+ @staticmethod
+ def get():
+ out = torch.empty(5, 2)
+ out[:, 0] = torch.randn(5)
+ out[:, 1] = 0.1 * torch.ones(5)
+ label = torch.arange(0, 10)
+ return out.view(10, 1), F.one_hot(label, 10)
+
+ @staticmethod
+ def rollout(trajectory, gamma):
+ out = [trajectory[-1]]
+ for i in reversed(range(9)):
+ out.append(trajectory[i] + gamma[i] * out[-1].clone().detach_())
+ out.reverse()
+ return torch.hstack(out).view(10, 1)
+
+ class ValueNetwork(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.fc = nn.Linear(10, 1)
+
+ def forward(self, x):
+ return self.fc(x)
+
+ torch.manual_seed(0)
+ inner_iters = 1
+ outer_iters = 10000
+ net = ValueNetwork()
+ inner_optimizer = torchopt.MetaSGD(net, lr=5e-1)
+ gamma = torch.zeros(9, requires_grad=True)
+ meta_optimizer = torchopt.SGD([gamma], lr=5e-1)
+ net_state = torchopt.extract_state_dict(net)
+ for i in range(outer_iters):
+ for j in range(inner_iters):
+ trajectory, state = Rollout.get()
+ backup = Rollout.rollout(trajectory, torch.sigmoid(gamma))
+ pred_value = net(state.float())
+
+ loss = F.mse_loss(pred_value, backup)
+ inner_optimizer.step(loss)
+
+ trajectory, state = Rollout.get()
+ pred_value = net(state.float())
+ backup = Rollout.rollout(trajectory, torch.ones_like(gamma))
+
+ loss = F.mse_loss(pred_value, backup)
+ meta_optimizer.zero_grad()
+ loss.backward()
+ meta_optimizer.step()
+ torchopt.recover_state_dict(net, net_state)
+ if i % 100 == 0:
+ with torch.no_grad():
+ print(f"epoch {i} | gamma: {torch.sigmoid(gamma)}")
if __name__ == "__main__":
- test_gamma()
+ test_gamma()
diff --git a/examples/few-shot/README.md b/examples/few-shot/README.md
index d617b62d..0437541a 100644
--- a/examples/few-shot/README.md
+++ b/examples/few-shot/README.md
@@ -1,15 +1,18 @@
# MAML few-shot Omniglot classification-examples
-Code On MAML few-shot Omniglot classification in paper [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) using `TorchOpt`. We use `MetaSGD` as the inner-loop optimiser.
+Code On MAML few-shot Omniglot classification in paper [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) using TorchOpt. We use `MetaSGD` as the inner-loop optimiser.
+
+## Usage
-# Usage
```bash
-### Run
+### Run
python3 maml-omniglot.py
```
-# Results
+## Results
+
The figure illustrate the experimental result.
+
-

+
diff --git a/examples/few-shot/maml-omniglot.py b/examples/few-shot/maml-omniglot.py
index 1d942593..856f8f01 100644
--- a/examples/few-shot/maml-omniglot.py
+++ b/examples/few-shot/maml-omniglot.py
@@ -47,233 +47,224 @@
import numpy as np
import pandas as pd
import torch
+import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-from support.omniglot_loaders import OmniglotNShot
-from torch import nn
-import TorchOpt
+import torchopt
+
+from .support.omniglot_loaders import OmniglotNShot
mpl.use('Agg')
plt.style.use('bmh')
def main():
- argparser = argparse.ArgumentParser()
- argparser.add_argument('--n_way', type=int, help='n way', default=5)
- argparser.add_argument(
- '--k_spt', type=int, help='k shot for support set', default=5
- )
- argparser.add_argument(
- '--k_qry', type=int, help='k shot for query set', default=15
- )
- argparser.add_argument(
- '--task_num',
- type=int,
- help='meta batch size, namely task num',
- default=32
- )
- argparser.add_argument('--seed', type=int, help='random seed', default=1)
- args = argparser.parse_args()
-
- torch.manual_seed(args.seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(args.seed)
- np.random.seed(args.seed)
- rng = np.random.default_rng(args.seed)
-
- # Set up the Omniglot loader.
- device = torch.device('cuda:0')
- db = OmniglotNShot(
- '/tmp/omniglot-data',
- batchsz=args.task_num,
- n_way=args.n_way,
- k_shot=args.k_spt,
- k_query=args.k_qry,
- imgsz=28,
- rng=rng,
- device=device,
- )
-
- # Create a vanilla PyTorch neural network that will be
- # automatically monkey-patched by higher later.
- # Before higher, models could *not* be created like this
- # and the parameters needed to be manually updated and copied
- # for the updates.
- net = nn.Sequential(
- nn.Conv2d(1, 64, 3), nn.BatchNorm2d(64, momentum=1., affine=True),
- nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3),
- nn.BatchNorm2d(64, momentum=1., affine=True), nn.ReLU(inplace=False),
- nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3),
- nn.BatchNorm2d(64, momentum=1., affine=True), nn.ReLU(inplace=False),
- nn.MaxPool2d(2, 2), nn.Flatten(), nn.Linear(64, args.n_way)
- ).to(device)
-
- # We will use Adam to (meta-)optimize the initial parameters
- # to be adapted.
- meta_opt = optim.Adam(net.parameters(), lr=1e-3)
-
- log = []
- for epoch in range(10):
- train(db, net, meta_opt, epoch, log)
- test(db, net, epoch, log)
- plot(log)
-
+ argparser = argparse.ArgumentParser()
+ argparser.add_argument('--n_way', type=int, help='n way', default=5)
+ argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
+ argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
+ argparser.add_argument(
+ '--task_num', type=int, help='meta batch size, namely task num', default=32
+ )
+ argparser.add_argument('--seed', type=int, help='random seed', default=1)
+ args = argparser.parse_args()
+
+ torch.manual_seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+ np.random.seed(args.seed)
+ rng = np.random.default_rng(args.seed)
+
+ # Set up the Omniglot loader.
+ device = torch.device('cuda:0')
+ db = OmniglotNShot(
+ '/tmp/omniglot-data',
+ batchsz=args.task_num,
+ n_way=args.n_way,
+ k_shot=args.k_spt,
+ k_query=args.k_qry,
+ imgsz=28,
+ rng=rng,
+ device=device,
+ )
-def train(db, net, meta_opt, epoch, log):
- net.train()
- n_train_iter = db.x_train.shape[0] // db.batchsz
- inner_opt = TorchOpt.MetaSGD(net, lr=1e-1)
+ # Create a vanilla PyTorch neural network that will be
+ # automatically monkey-patched by higher later.
+ # Before higher, models could *not* be created like this
+ # and the parameters needed to be manually updated and copied
+ # for the updates.
+ net = nn.Sequential(
+ nn.Conv2d(1, 64, 3), nn.BatchNorm2d(64, momentum=1., affine=True), nn.ReLU(inplace=False),
+ nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3), nn.BatchNorm2d(64, momentum=1., affine=True),
+ nn.ReLU(inplace=False), nn.MaxPool2d(2, 2), nn.Conv2d(64, 64, 3),
+ nn.BatchNorm2d(64, momentum=1., affine=True), nn.ReLU(inplace=False), nn.MaxPool2d(2, 2),
+ nn.Flatten(), nn.Linear(64, args.n_way)
+ ).to(device)
+
+ # We will use Adam to (meta-)optimize the initial parameters
+ # to be adapted.
+ meta_opt = optim.Adam(net.parameters(), lr=1e-3)
+
+ log = []
+ for epoch in range(10):
+ train(db, net, meta_opt, epoch, log)
+ test(db, net, epoch, log)
+ plot(log)
- for batch_idx in range(n_train_iter):
- start_time = time.time()
- # Sample a batch of support and query images and labels.
- x_spt, y_spt, x_qry, y_qry = db.next()
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
+def train(db, net, meta_opt, epoch, log):
+ net.train()
+ n_train_iter = db.x_train.shape[0] // db.batchsz
+ inner_opt = torchopt.MetaSGD(net, lr=1e-1)
+
+ for batch_idx in range(n_train_iter):
+ start_time = time.time()
+ # Sample a batch of support and query images and labels.
+ x_spt, y_spt, x_qry, y_qry = db.next()
+
+ task_num, setsz, c_, h, w = x_spt.size()
+ querysz = x_qry.size(1)
+
+ # TODO: Maybe pull this out into a separate module so it
+ # doesn't have to be duplicated between `train` and `test`?
+
+ # Initialize the inner optimizer to adapt the parameters to
+ # the support set.
+ n_inner_iter = 5
+
+ qry_losses = []
+ qry_accs = []
+ meta_opt.zero_grad()
+
+ net_state_dict = torchopt.extract_state_dict(net)
+ optim_state_dict = torchopt.extract_state_dict(inner_opt)
+ for i in range(task_num):
+ # Optimize the likelihood of the support set by taking
+ # gradient steps w.r.t. the model's parameters.
+ # This adapts the model's meta-parameters to the task.
+ # higher is able to automatically keep copies of
+ # your network's parameters as they are being updated.
+ for _ in range(n_inner_iter):
+ spt_logits = net(x_spt[i])
+ spt_loss = F.cross_entropy(spt_logits, y_spt[i])
+ inner_opt.step(spt_loss)
+
+ # The final set of adapted parameters will induce some
+ # final loss and accuracy on the query dataset.
+ # These will be used to update the model's meta-parameters.
+ qry_logits = net(x_qry[i])
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+ qry_losses.append(qry_loss.detach())
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
+ qry_accs.append(qry_acc)
+
+ # Update the model's meta-parameters to optimize the query
+ # losses across all of the tasks sampled in this batch.
+ # This unrolls through the gradient steps.
+ qry_loss.backward()
+
+ torchopt.recover_state_dict(net, net_state_dict)
+ torchopt.recover_state_dict(inner_opt, optim_state_dict)
+
+ meta_opt.step()
+ qry_losses = sum(qry_losses) / task_num
+ qry_accs = 100. * sum(qry_accs) / task_num
+ i = epoch + float(batch_idx) / n_train_iter
+ iter_time = time.time() - start_time
+
+ print(
+ f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+ )
+
+ log.append(
+ {
+ 'epoch': i,
+ 'loss': qry_losses,
+ 'acc': qry_accs,
+ 'mode': 'train',
+ 'time': time.time(),
+ }
+ )
- # TODO: Maybe pull this out into a separate module so it
- # doesn't have to be duplicated between `train` and `test`?
- # Initialize the inner optimizer to adapt the parameters to
- # the support set.
- n_inner_iter = 5
+def test(db, net, epoch, log):
+ # Crucially in our testing procedure here, we do *not* fine-tune
+ # the model during testing for simplicity.
+ # Most research papers using MAML for this task do an extra
+ # stage of fine-tuning here that should be added if you are
+ # adapting this code for research.
+ net.train()
+ n_test_iter = db.x_test.shape[0] // db.batchsz
+ inner_opt = torchopt.MetaSGD(net, lr=1e-1)
qry_losses = []
qry_accs = []
- meta_opt.zero_grad()
-
- net_state_dict = TorchOpt.extract_state_dict(net)
- optim_state_dict = TorchOpt.extract_state_dict(inner_opt)
- for i in range(task_num):
- # Optimize the likelihood of the support set by taking
- # gradient steps w.r.t. the model's parameters.
- # This adapts the model's meta-parameters to the task.
- # higher is able to automatically keep copies of
- # your network's parameters as they are being updated.
- for _ in range(n_inner_iter):
- spt_logits = net(x_spt[i])
- spt_loss = F.cross_entropy(spt_logits, y_spt[i])
- inner_opt.step(spt_loss)
-
- # The final set of adapted parameters will induce some
- # final loss and accuracy on the query dataset.
- # These will be used to update the model's meta-parameters.
- qry_logits = net(x_qry[i])
- qry_loss = F.cross_entropy(qry_logits, y_qry[i])
- qry_losses.append(qry_loss.detach())
- qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
- qry_accs.append(qry_acc)
-
- # Update the model's meta-parameters to optimize the query
- # losses across all of the tasks sampled in this batch.
- # This unrolls through the gradient steps.
- qry_loss.backward()
-
- TorchOpt.recover_state_dict(net, net_state_dict)
- TorchOpt.recover_state_dict(inner_opt, optim_state_dict)
-
- meta_opt.step()
- qry_losses = sum(qry_losses) / task_num
- qry_accs = 100. * sum(qry_accs) / task_num
- i = epoch + float(batch_idx) / n_train_iter
- iter_time = time.time() - start_time
-
- print(
- f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
- )
+ for batch_idx in range(n_test_iter):
+ x_spt, y_spt, x_qry, y_qry = db.next('test')
+
+ task_num, setsz, c_, h, w = x_spt.size()
+ querysz = x_qry.size(1)
+
+ # TODO: Maybe pull this out into a separate module so it
+ # doesn't have to be duplicated between `train` and `test`?
+ n_inner_iter = 5
+
+ net_state_dict = torchopt.extract_state_dict(net)
+ optim_state_dict = torchopt.extract_state_dict(inner_opt)
+ for i in range(task_num):
+ # Optimize the likelihood of the support set by taking
+ # gradient steps w.r.t. the model's parameters.
+ # This adapts the model's meta-parameters to the task.
+ for _ in range(n_inner_iter):
+ spt_logits = net(x_spt[i])
+ spt_loss = F.cross_entropy(spt_logits, y_spt[i])
+ inner_opt.step(spt_loss)
+
+ # The query loss and acc induced by these parameters.
+ qry_logits = net(x_qry[i]).detach()
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
+ qry_losses.append(qry_loss.detach())
+ qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
+
+ torchopt.recover_state_dict(net, net_state_dict)
+ torchopt.recover_state_dict(inner_opt, optim_state_dict)
+
+ qry_losses = torch.cat(qry_losses).mean().item()
+ qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
+ print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
log.append(
- {
- 'epoch': i,
- 'loss': qry_losses,
- 'acc': qry_accs,
- 'mode': 'train',
- 'time': time.time(),
- }
+ {
+ 'epoch': epoch + 1,
+ 'loss': qry_losses,
+ 'acc': qry_accs,
+ 'mode': 'test',
+ 'time': time.time(),
+ }
)
-def test(db, net, epoch, log):
- # Crucially in our testing procedure here, we do *not* fine-tune
- # the model during testing for simplicity.
- # Most research papers using MAML for this task do an extra
- # stage of fine-tuning here that should be added if you are
- # adapting this code for research.
- net.train()
- n_test_iter = db.x_test.shape[0] // db.batchsz
- inner_opt = TorchOpt.MetaSGD(net, lr=1e-1)
-
- qry_losses = []
- qry_accs = []
-
- for batch_idx in range(n_test_iter):
- x_spt, y_spt, x_qry, y_qry = db.next('test')
-
- task_num, setsz, c_, h, w = x_spt.size()
- querysz = x_qry.size(1)
-
- # TODO: Maybe pull this out into a separate module so it
- # doesn't have to be duplicated between `train` and `test`?
- n_inner_iter = 5
-
- net_state_dict = TorchOpt.extract_state_dict(net)
- optim_state_dict = TorchOpt.extract_state_dict(inner_opt)
- for i in range(task_num):
- # Optimize the likelihood of the support set by taking
- # gradient steps w.r.t. the model's parameters.
- # This adapts the model's meta-parameters to the task.
- for _ in range(n_inner_iter):
- spt_logits = net(x_spt[i])
- spt_loss = F.cross_entropy(spt_logits, y_spt[i])
- inner_opt.step(spt_loss)
-
- # The query loss and acc induced by these parameters.
- qry_logits = net(x_qry[i]).detach()
- qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
- qry_losses.append(qry_loss.detach())
- qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
-
- TorchOpt.recover_state_dict(net, net_state_dict)
- TorchOpt.recover_state_dict(inner_opt, optim_state_dict)
-
- qry_losses = torch.cat(qry_losses).mean().item()
- qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
- print(
- f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
- )
- log.append(
- {
- 'epoch': epoch + 1,
- 'loss': qry_losses,
- 'acc': qry_accs,
- 'mode': 'test',
- 'time': time.time(),
- }
- )
-
-
def plot(log):
- # Generally you should pull your plotting code out of your training
- # script but we are doing it here for brevity.
- df = pd.DataFrame(log)
-
- fig, ax = plt.subplots(figsize=(6, 4))
- train_df = df[df['mode'] == 'train']
- test_df = df[df['mode'] == 'test']
- ax.plot(train_df['epoch'], train_df['acc'], label='Train')
- ax.plot(test_df['epoch'], test_df['acc'], label='Test')
- ax.set_xlabel('Epoch')
- ax.set_ylabel('Accuracy')
- ax.set_ylim(70, 100)
- fig.legend(ncol=2, loc='lower right')
- fig.tight_layout()
- fname = 'maml-accs.png'
- print(f'--- Plotting accuracy to {fname}')
- fig.savefig(fname)
- plt.close(fig)
+ # Generally you should pull your plotting code out of your training
+ # script but we are doing it here for brevity.
+ df = pd.DataFrame(log)
+
+ fig, ax = plt.subplots(figsize=(6, 4))
+ train_df = df[df['mode'] == 'train']
+ test_df = df[df['mode'] == 'test']
+ ax.plot(train_df['epoch'], train_df['acc'], label='Train')
+ ax.plot(test_df['epoch'], test_df['acc'], label='Test')
+ ax.set_xlabel('Epoch')
+ ax.set_ylabel('Accuracy')
+ ax.set_ylim(70, 100)
+ fig.legend(ncol=2, loc='lower right')
+ fig.tight_layout()
+ fname = 'maml-accs.png'
+ print(f'--- Plotting accuracy to {fname}')
+ fig.savefig(fname)
+ plt.close(fig)
if __name__ == '__main__':
- main()
+ main()
diff --git a/examples/few-shot/support/omniglot_loaders.py b/examples/few-shot/support/omniglot_loaders.py
index 9aa9f6ed..731c41be 100644
--- a/examples/few-shot/support/omniglot_loaders.py
+++ b/examples/few-shot/support/omniglot_loaders.py
@@ -20,7 +20,6 @@
import errno
import os
-import os.path
import numpy as np
import torch
@@ -30,122 +29,115 @@
class Omniglot(data.Dataset):
- urls = [
- 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
- 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
- ]
- raw_folder = 'raw'
- processed_folder = 'processed'
- training_file = 'training.pt'
- test_file = 'test.pt'
- '''
+ """
The items are (filename,category). The index of all the categories can be found in self.idx_classes
Args:
- root: the directory where the dataset will be stored
- transform: how to transform the input
- target_transform: how to transform the target
- download: need to download the dataset
- '''
-
- def __init__(
- self, root, transform=None, target_transform=None, download=False
- ):
- self.root = root
- self.transform = transform
- self.target_transform = target_transform
-
- if not self._check_exists():
- if download:
- self.download()
- else:
- raise RuntimeError(
- 'Dataset not found.' + ' You can use download=True to download it'
- )
-
- self.all_items = find_classes(
- os.path.join(self.root, self.processed_folder)
- )
- self.idx_classes = index_classes(self.all_items)
-
- def __getitem__(self, index):
- filename = self.all_items[index][0]
- img = str.join('/', [self.all_items[index][2], filename])
-
- target = self.idx_classes[self.all_items[index][1]]
- if self.transform is not None:
- img = self.transform(img)
- if self.target_transform is not None:
- target = self.target_transform(target)
-
- return img, target
-
- def __len__(self):
- return len(self.all_items)
-
- def _check_exists(self):
- return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
- os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
-
- def download(self):
- import zipfile
-
- from six.moves import urllib
-
- if self._check_exists():
- return
-
- # download files
- try:
- os.makedirs(os.path.join(self.root, self.raw_folder))
- os.makedirs(os.path.join(self.root, self.processed_folder))
- except OSError as e:
- if e.errno == errno.EEXIST:
- pass
- else:
- raise
-
- for url in self.urls:
- print('== Downloading ' + url)
- data = urllib.request.urlopen(url)
- filename = url.rpartition('/')[2]
- file_path = os.path.join(self.root, self.raw_folder, filename)
- with open(file_path, 'wb') as f:
- f.write(data.read())
- file_processed = os.path.join(self.root, self.processed_folder)
- print("== Unzip from " + file_path + " to " + file_processed)
- zip_ref = zipfile.ZipFile(file_path, 'r')
- zip_ref.extractall(file_processed)
- zip_ref.close()
- print("Download finished.")
+ """
+
+ urls = [
+ 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
+ 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
+ ]
+ raw_folder = 'raw'
+ processed_folder = 'processed'
+ training_file = 'training.pt'
+ test_file = 'test.pt'
+
+ def __init__(self, root, transform=None, target_transform=None, download=False):
+ self.root = root
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if not self._check_exists():
+ if download:
+ self.download()
+ else:
+ raise RuntimeError('Dataset not found. You can use download=True to download it')
+
+ self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
+ self.idx_classes = index_classes(self.all_items)
+
+ def __getitem__(self, index):
+ filename = self.all_items[index][0]
+ img = str.join('/', [self.all_items[index][2], filename])
+
+ target = self.idx_classes[self.all_items[index][1]]
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.all_items)
+
+ def _check_exists(self):
+ return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
+ os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
+
+ def download(self):
+ import zipfile
+
+ from six.moves import urllib
+
+ if self._check_exists():
+ return
+
+ # download files
+ try:
+ os.makedirs(os.path.join(self.root, self.raw_folder))
+ os.makedirs(os.path.join(self.root, self.processed_folder))
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ pass
+ else:
+ raise
+
+ for url in self.urls:
+ print('== Downloading ' + url)
+ data = urllib.request.urlopen(url)
+ filename = url.rpartition('/')[2]
+ file_path = os.path.join(self.root, self.raw_folder, filename)
+ with open(file_path, 'wb') as f:
+ f.write(data.read())
+ file_processed = os.path.join(self.root, self.processed_folder)
+ print("== Unzip from " + file_path + " to " + file_processed)
+ zip_ref = zipfile.ZipFile(file_path, 'r')
+ zip_ref.extractall(file_processed)
+ zip_ref.close()
+ print("Download finished.")
def find_classes(root_dir):
- retour = []
- for (root, dirs, files) in os.walk(root_dir):
- for f in files:
- if (f.endswith("png")):
- r = root.split('/')
- lr = len(r)
- retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
- print("== Found %d items " % len(retour))
- return retour
+ retour = []
+ for (root, dirs, files) in os.walk(root_dir):
+ for f in files:
+ if (f.endswith("png")):
+ r = root.split('/')
+ lr = len(r)
+ retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
+ print("== Found %d items " % len(retour))
+ return retour
def index_classes(items):
- idx = {}
- for i in items:
- if i[1] not in idx:
- idx[i[1]] = len(idx)
- print("== Found %d classes" % len(idx))
- return idx
+ idx = {}
+ for i in items:
+ if i[1] not in idx:
+ idx[i[1]] = len(idx)
+ print("== Found %d classes" % len(idx))
+ return idx
class OmniglotNShot:
- def __init__(
- self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None
- ):
- """
+ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None):
+ """
Different from mnistNShot, the
:param root:
:param batchsz: task num
@@ -155,178 +147,168 @@ def __init__(
:param imgsz:
"""
- self.resize = imgsz
- self.rng = rng
- self.device = device
- if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
- # if root/data.npy does not exist, just download it
- self.x = Omniglot(
- root,
- download=True,
- transform=transforms.Compose(
- [
- lambda x: Image.open(x).convert('L'),
- lambda x: x.resize((imgsz, imgsz)),
- lambda x: np.reshape(x, (imgsz, imgsz, 1)),
- lambda x: np.transpose(x, [2, 0, 1]), lambda x: x / 255.
- ]
- ),
- )
-
- temp = dict(
- ) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
- for (img, label) in self.x:
- if label in temp.keys():
- temp[label].append(img)
+ self.resize = imgsz
+ self.rng = rng
+ self.device = device
+ if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
+ # if root/data.npy does not exist, just download it
+ self.x = Omniglot(
+ root,
+ download=True,
+ transform=transforms.Compose(
+ [
+ lambda x: Image.open(x).convert('L'), lambda x: x.resize((imgsz, imgsz)),
+ lambda x: np.reshape(x, (imgsz, imgsz, 1)),
+ lambda x: np.transpose(x, [2, 0, 1]), lambda x: x / 255.
+ ]
+ ),
+ )
+
+ temp = dict(
+ ) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
+ for (img, label) in self.x:
+ if label in temp.keys():
+ temp[label].append(img)
+ else:
+ temp[label] = [img]
+
+ self.x = []
+ for label, imgs in temp.items(): # labels info deserted , each label contains 20imgs
+ self.x.append(np.array(imgs))
+
+ # as different class may have different number of imgs
+ self.x = np.array(self.x).astype(np.float) # [[20 imgs],..., 1623 classes in total]
+ # each character contains 20 imgs
+ print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1]
+ temp = [] # Free memory
+ # save all dataset into npy file.
+ np.save(os.path.join(root, 'omniglot.npy'), self.x)
+ print('write into omniglot.npy.')
else:
- temp[label] = [img]
-
- self.x = []
- for label, imgs in temp.items(
- ): # labels info deserted , each label contains 20imgs
- self.x.append(np.array(imgs))
-
- # as different class may have different number of imgs
- self.x = np.array(self.x).astype(
- np.float
- ) # [[20 imgs],..., 1623 classes in total]
- # each character contains 20 imgs
- print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1]
- temp = [] # Free memory
- # save all dataset into npy file.
- np.save(os.path.join(root, 'omniglot.npy'), self.x)
- print('write into omniglot.npy.')
- else:
- # if data.npy exists, just load it.
- self.x = np.load(os.path.join(root, 'omniglot.npy'))
- print('load from omniglot.npy.')
-
- # [1623, 20, 84, 84, 1]
- # TODO: can not shuffle here, we must keep training and test set distinct!
- self.x_train, self.x_test = self.x[:1200], self.x[1200:]
-
- # self.normalization()
-
- self.batchsz = batchsz
- self.n_cls = self.x.shape[0] # 1623
- self.n_way = n_way # n way
- self.k_shot = k_shot # k shot
- self.k_query = k_query # k query
- assert (k_shot + k_query) <= 20
-
- # save pointer of current read batch in total cache
- self.indexes = {"train": 0, "test": 0}
- self.datasets = {
- "train": self.x_train,
- "test": self.x_test
- } # original data cached
- print("DB: train", self.x_train.shape, "test", self.x_test.shape)
-
- self.datasets_cache = {
- "train": self.load_data_cache(self.datasets["train"]
- ), # current epoch data cached
- "test": self.load_data_cache(self.datasets["test"])
- }
-
- def normalization(self):
- """
- Normalizes our data, to have a mean of 0 and sdt of 1
+ # if data.npy exists, just load it.
+ self.x = np.load(os.path.join(root, 'omniglot.npy'))
+ print('load from omniglot.npy.')
+
+ # [1623, 20, 84, 84, 1]
+ # TODO: can not shuffle here, we must keep training and test set distinct!
+ self.x_train, self.x_test = self.x[:1200], self.x[1200:]
+
+ # self.normalization()
+
+ self.batchsz = batchsz
+ self.n_cls = self.x.shape[0] # 1623
+ self.n_way = n_way # n way
+ self.k_shot = k_shot # k shot
+ self.k_query = k_query # k query
+ assert (k_shot + k_query) <= 20
+
+ # save pointer of current read batch in total cache
+ self.indexes = {"train": 0, "test": 0}
+ self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached
+ print("DB: train", self.x_train.shape, "test", self.x_test.shape)
+
+ self.datasets_cache = {
+ "train": self.load_data_cache(self.datasets["train"]), # current epoch data cached
+ "test": self.load_data_cache(self.datasets["test"])
+ }
+
+ def normalization(self):
+ """
+ Normalizes our data, to have a mean of 0 and sdt of 1
+ """
+ self.mean = np.mean(self.x_train)
+ self.std = np.std(self.x_train)
+ self.max = np.max(self.x_train)
+ self.min = np.min(self.x_train)
+ # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
+ self.x_train = (self.x_train - self.mean) / self.std
+ self.x_test = (self.x_test - self.mean) / self.std
+
+ self.mean = np.mean(self.x_train)
+ self.std = np.std(self.x_train)
+ self.max = np.max(self.x_train)
+ self.min = np.min(self.x_train)
+
+ # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
+
+ def load_data_cache(self, data_pack):
"""
- self.mean = np.mean(self.x_train)
- self.std = np.std(self.x_train)
- self.max = np.max(self.x_train)
- self.min = np.min(self.x_train)
- # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
- self.x_train = (self.x_train - self.mean) / self.std
- self.x_test = (self.x_test - self.mean) / self.std
-
- self.mean = np.mean(self.x_train)
- self.std = np.std(self.x_train)
- self.max = np.max(self.x_train)
- self.min = np.min(self.x_train)
-
- # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
-
- def load_data_cache(self, data_pack):
- """
Collects several batches data for N-shot learning
:param data_pack: [cls_num, 20, 84, 84, 1]
:return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
"""
- # take 5 way 1 shot as example: 5 * 1
- setsz = self.k_shot * self.n_way
- querysz = self.k_query * self.n_way
- data_cache = []
-
- # print('preload next 50 caches of batchsz of batch.')
- for sample in range(10): # num of episodes
-
- x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
- for i in range(self.batchsz): # one batch means one set
-
- x_spt, y_spt, x_qry, y_qry = [], [], [], []
- selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False)
-
- for j, cur_class in enumerate(selected_cls):
-
- selected_img = self.rng.choice(20, self.k_shot + self.k_query, False)
-
- # meta-training and meta-test
- x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
- x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
- y_spt.append([j for _ in range(self.k_shot)])
- y_qry.append([j for _ in range(self.k_query)])
-
- # shuffle inside a batch
- perm = self.rng.permutation(self.n_way * self.k_shot)
- x_spt = np.array(x_spt).reshape(
- self.n_way * self.k_shot, 1, self.resize, self.resize
- )[perm]
- y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
- perm = self.rng.permutation(self.n_way * self.k_query)
- x_qry = np.array(x_qry).reshape(
- self.n_way * self.k_query, 1, self.resize, self.resize
- )[perm]
- y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
-
- # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
- x_spts.append(x_spt)
- y_spts.append(y_spt)
- x_qrys.append(x_qry)
- y_qrys.append(y_qry)
-
- # [b, setsz, 1, 84, 84]
- x_spts = np.array(x_spts).astype(
- np.float32
- ).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
- y_spts = np.array(y_spts).astype(np.int).reshape(self.batchsz, setsz)
- # [b, qrysz, 1, 84, 84]
- x_qrys = np.array(x_qrys).astype(
- np.float32
- ).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
- y_qrys = np.array(y_qrys).astype(np.int).reshape(self.batchsz, querysz)
-
- x_spts, y_spts, x_qrys, y_qrys = [
- torch.from_numpy(z).to(self.device)
- for z in [x_spts, y_spts, x_qrys, y_qrys]
- ]
-
- data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
-
- return data_cache
-
- def next(self, mode='train'):
- """
+
+ # take 5 way 1 shot as example: 5 * 1
+ setsz = self.k_shot * self.n_way
+ querysz = self.k_query * self.n_way
+ data_cache = []
+
+ # print('preload next 50 caches of batchsz of batch.')
+ for sample in range(10): # num of episodes
+
+ x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
+ for i in range(self.batchsz): # one batch means one set
+
+ x_spt, y_spt, x_qry, y_qry = [], [], [], []
+ selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False)
+
+ for j, cur_class in enumerate(selected_cls):
+
+ selected_img = self.rng.choice(20, self.k_shot + self.k_query, False)
+
+ # meta-training and meta-test
+ x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
+ x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
+ y_spt.append([j for _ in range(self.k_shot)])
+ y_qry.append([j for _ in range(self.k_query)])
+
+ # shuffle inside a batch
+ perm = self.rng.permutation(self.n_way * self.k_shot)
+ x_spt = np.array(x_spt) \
+ .reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
+ y_spt = np.array(y_spt) \
+ .reshape(self.n_way * self.k_shot)[perm]
+ perm = self.rng.permutation(self.n_way * self.k_query)
+ x_qry = np.array(x_qry) \
+ .reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
+ y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
+
+ # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
+ x_spts.append(x_spt)
+ y_spts.append(y_spt)
+ x_qrys.append(x_qry)
+ y_qrys.append(y_qry)
+
+ # [b, setsz, 1, 84, 84]
+ x_spts = np.array(x_spts, dtype=np.float32) \
+ .reshape(self.batchsz, setsz, 1, self.resize, self.resize)
+ y_spts = np.array(y_spts, dtype=np.int).reshape(self.batchsz, setsz)
+ # [b, qrysz, 1, 84, 84]
+ x_qrys = np.array(x_qrys, dtype=np.float32) \
+ .reshape(self.batchsz, querysz, 1, self.resize, self.resize)
+ y_qrys = np.array(y_qrys, dtype=np.int).reshape(self.batchsz, querysz)
+
+ x_spts, y_spts, x_qrys, y_qrys = [
+ torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys]
+ ]
+
+ data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
+
+ return data_cache
+
+ def next(self, mode='train'):
+ """
Gets next batch from the dataset with name.
:param mode: The name of the splitting (one of "train", "val", "test")
:return:
"""
- # update cache if indexes is larger cached num
- if self.indexes[mode] >= len(self.datasets_cache[mode]):
- self.indexes[mode] = 0
- self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
- next_batch = self.datasets_cache[mode][self.indexes[mode]]
- self.indexes[mode] += 1
+ # update cache if indexes is larger cached num
+ if self.indexes[mode] >= len(self.datasets_cache[mode]):
+ self.indexes[mode] = 0
+ self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
+
+ next_batch = self.datasets_cache[mode][self.indexes[mode]]
+ self.indexes[mode] += 1
- return next_batch
+ return next_batch
diff --git a/examples/visualize.py b/examples/visualize.py
index 4e7d2684..028669e9 100644
--- a/examples/visualize.py
+++ b/examples/visualize.py
@@ -14,73 +14,67 @@
# ==============================================================================
import torch
+import torch.nn as nn
+import torch.nn.functional as F
import torchviz
-from torch import nn
-from torch.nn import functional as F
-import TorchOpt
+import torchopt
class Net(nn.Module):
- def __init__(self, dim):
- super().__init__()
- self.fc = nn.Linear(dim, 1)
+ def __init__(self, dim):
+ super().__init__()
+ self.fc = nn.Linear(dim, 1)
- def forward(self, x, meta_param):
- return self.fc(x) + meta_param
+ def forward(self, x, meta_param):
+ return self.fc(x) + meta_param
def draw_torchviz():
- net = Net(dim).cuda()
- optimizer = TorchOpt.MetaAdam(net, lr=1e-3, use_accelerated_op=False)
- meta_param = torch.tensor(1., requires_grad=True)
-
- xs = torch.ones(batch_size, dim).cuda()
-
- pred = net(xs, meta_param)
- loss = F.mse_loss(pred, torch.ones_like(pred))
- optimizer.step(loss)
-
- pred = net(xs, meta_param)
- loss = F.mse_loss(pred, torch.ones_like(pred))
- # draw computation graph
- torchviz.make_dot(loss).render("torchviz_graph", format="svg")
-
-
-def draw_TorchOpt():
- net = Net(dim).cuda()
- optimizer = TorchOpt.MetaAdam(net, lr=1e-3, use_accelerated_op=True)
- meta_param = torch.tensor(1., requires_grad=True)
-
- xs = torch.ones(batch_size, dim).cuda()
-
- pred = net(xs, meta_param)
- loss = F.mse_loss(pred, torch.ones_like(pred))
- # set enable_visual
- net_state_0 = TorchOpt.extract_state_dict(
- net, enable_visual=True, visual_prefix='step0.'
- )
- optimizer.step(loss)
- # set enable_visual
- net_state_1 = TorchOpt.extract_state_dict(
- net, enable_visual=True, visual_prefix='step1.'
- )
-
- pred = net(xs, meta_param)
- loss = F.mse_loss(pred, torch.ones_like(pred))
- # draw computation graph
- TorchOpt.visual.make_dot(
- loss, [net_state_0, net_state_1, {
- meta_param: "meta_param"
- }]
- ).render(
- "TorchOpt_graph", format="svg"
- )
+ net = Net(dim).cuda()
+ optimizer = torchopt.MetaAdam(net, lr=1e-3, use_accelerated_op=False)
+ meta_param = torch.tensor(1., requires_grad=True)
+
+ xs = torch.ones(batch_size, dim).cuda()
+
+ pred = net(xs, meta_param)
+ loss = F.mse_loss(pred, torch.ones_like(pred))
+ optimizer.step(loss)
+
+ pred = net(xs, meta_param)
+ loss = F.mse_loss(pred, torch.ones_like(pred))
+ # draw computation graph
+ torchviz.make_dot(loss).render("torchviz_graph", format="svg")
+
+
+def draw_torchopt():
+ net = Net(dim).cuda()
+ optimizer = torchopt.MetaAdam(net, lr=1e-3, use_accelerated_op=True)
+ meta_param = torch.tensor(1., requires_grad=True)
+
+ xs = torch.ones(batch_size, dim).cuda()
+
+ pred = net(xs, meta_param)
+ loss = F.mse_loss(pred, torch.ones_like(pred))
+ # set enable_visual
+ net_state_0 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step0.')
+ optimizer.step(loss)
+ # set enable_visual
+ net_state_1 = torchopt.extract_state_dict(net, enable_visual=True, visual_prefix='step1.')
+
+ pred = net(xs, meta_param)
+ loss = F.mse_loss(pred, torch.ones_like(pred))
+ # draw computation graph
+ torchopt.visual.make_dot(loss, [net_state_0, net_state_1, {
+ meta_param: "meta_param"
+ }]).render(
+ "torchopt_graph", format="svg"
+ )
if __name__ == '__main__':
- dim = 5
- batch_size = 2
- draw_torchviz()
- draw_TorchOpt()
+ dim = 5
+ batch_size = 2
+ draw_torchviz()
+ draw_torchopt()
diff --git a/include/adam_op/adam_op.h b/include/adam_op/adam_op.h
index 3499a3e9..33aa53b7 100644
--- a/include/adam_op/adam_op.h
+++ b/include/adam_op/adam_op.h
@@ -18,9 +18,9 @@
#include
-#include "adam_op/common.h"
+#include "common.h"
-namespace TorchOpt {
+namespace torchopt {
TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
const torch::Tensor& mu,
const torch::Tensor& nu, const float b1,
@@ -51,4 +51,4 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
const torch::Tensor& new_mu,
const torch::Tensor& new_nu, const float b1,
const float b2, const int count);
-} // namespace TorchOpt
+} // namespace torchopt
diff --git a/include/adam_op/adam_op_impl.cuh b/include/adam_op/adam_op_impl.cuh
index 9e37df1b..bc29171f 100644
--- a/include/adam_op/adam_op_impl.cuh
+++ b/include/adam_op/adam_op_impl.cuh
@@ -18,9 +18,9 @@
#include
-#include "adam_op/common.h"
+#include "common.h"
-namespace TorchOpt {
+namespace torchopt {
TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
const torch::Tensor &mu,
const torch::Tensor &nu, const float b1,
@@ -53,4 +53,4 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
const torch::Tensor &new_nu,
const float b1, const float b2,
const int count);
-} // namespace TorchOpt
+} // namespace torchopt
diff --git a/include/adam_op/adam_op_impl.h b/include/adam_op/adam_op_impl.h
index 96393d16..2514aa48 100644
--- a/include/adam_op/adam_op_impl.h
+++ b/include/adam_op/adam_op_impl.h
@@ -18,9 +18,9 @@
#include
-#include "adam_op/common.h"
+#include "common.h"
-namespace TorchOpt {
+namespace torchopt {
TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
const torch::Tensor& mu,
const torch::Tensor& nu, const float b1,
@@ -52,4 +52,4 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
const torch::Tensor& new_nu,
const float b1, const float b2,
const int count);
-} // namespace TorchOpt
+} // namespace torchopt
diff --git a/include/common.h b/include/common.h
index e5c681b6..e4362013 100644
--- a/include/common.h
+++ b/include/common.h
@@ -18,7 +18,7 @@
#include
-namespace TorchOpt {
+namespace torchopt {
template
using TensorArray = std::array;
}
diff --git a/include/utils.h b/include/utils.h
index ddc0a992..92f9bad0 100644
--- a/include/utils.h
+++ b/include/utils.h
@@ -22,7 +22,7 @@
#define __forceinline__ __inline__ __attribute__((always_inline))
#endif
-namespace TorchOpt {
+namespace torchopt {
__forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) {
const auto dim = tensor.dim();
size_t n = 1;
@@ -31,4 +31,4 @@ __forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) {
}
return n;
}
-} // namespace TorchOpt
+} // namespace torchopt
diff --git a/setup.cfg b/setup.cfg
index 52dc6283..f43fc9bc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,14 +1,15 @@
[yapf]
based_on_style = yapf
+indent_width = 4
+continuation_indent_width = 4
spaces_before_comment = 2
dedent_closing_brackets = true
-column_limit = 79
-continuation_indent_width = 2
+column_limit = 100
[flake8]
exclude =
.git
-indent_size = 2
+indent_size = 4
[pydocstyle]
convention = google
@@ -16,7 +17,7 @@ convention = google
[isort]
profile = black
multi_line_output = 3
-indent = 2
+indent = 4
line_length = 79
[mypy]
@@ -39,4 +40,4 @@ warn_unused_configs = True
warn_unused_ignores = True
[doc8]
-max-line-length = 200
\ No newline at end of file
+max-line-length = 200
diff --git a/setup.py b/setup.py
index 9c201878..1c4df4a0 100644
--- a/setup.py
+++ b/setup.py
@@ -8,113 +8,103 @@
class MyBuild(build_ext):
-
- def run(self):
- self.build_cmake()
-
- def copy(self, build_temp):
- from distutils.file_util import copy_file
- cwd = str(pathlib.Path().absolute())
- src = os.path.join('.', build_temp, 'src')
- ops = os.listdir(src)
- for op in ops:
- op_path = os.path.join(src, op)
- if not os.path.isdir(op_path):
- continue
- files = os.listdir(op_path)
- for file in files:
- if file.split('.')[-1] == 'so':
- copy_file(
- os.path.join(op_path, file), os.path.join(cwd, 'TorchOpt', '_lib')
- )
-
- def build_cmake(self):
- cwd = pathlib.Path().absolute()
-
- build_temp = f"{pathlib.Path(self.build_temp)}"
- os.makedirs(build_temp, exist_ok=True)
-
- config = "Debug" if self.debug else "Release"
-
- PYTHON_INCLUDE_DIR = ""
- for path in self.include_dirs:
- PYTHON_INCLUDE_DIR += path + ';'
-
- TORCH_INCLUDE_PATH = ""
- for path in cpp_extension.include_paths():
- TORCH_INCLUDE_PATH += path + ';'
-
- TORCH_LIBRARY_PATH = ""
- for path in cpp_extension.library_paths():
- TORCH_LIBRARY_PATH += path + ';'
-
- cmake_args = [
- "-DPYTHON_INCLUDE_DIR=" + PYTHON_INCLUDE_DIR,
- "-DTORCH_INCLUDE_PATH=" + TORCH_INCLUDE_PATH,
- "-DTORCH_LIBRARY_PATH=" + TORCH_LIBRARY_PATH,
- "-DCMAKE_BUILD_TYPE=" + config
- ]
-
- build_args = ["--config", config, "--", "-j4"]
-
- os.chdir(build_temp)
- self.spawn(["cmake", f"{str(cwd)}"] + cmake_args)
- if not self.dry_run:
- self.spawn(["cmake", "--build", "."] + build_args)
- os.chdir(str(cwd))
- self.copy(build_temp)
+ def run(self):
+ self.build_cmake()
+
+ def copy(self, build_temp):
+ from distutils.file_util import copy_file
+ cwd = str(pathlib.Path().absolute())
+ src = os.path.join('.', build_temp, 'src')
+ ops = os.listdir(src)
+ for op in ops:
+ op_path = os.path.join(src, op)
+ if not os.path.isdir(op_path):
+ continue
+ files = os.listdir(op_path)
+ for file in files:
+ if file.split('.')[-1] == 'so':
+ copy_file(os.path.join(op_path, file),
+ os.path.join(cwd, 'torchopt', '_lib'))
+
+ def build_cmake(self):
+ cwd = pathlib.Path().absolute()
+
+ build_temp = str(pathlib.Path(self.build_temp))
+ os.makedirs(build_temp, exist_ok=True)
+
+ config = "Debug" if self.debug else "Release"
+
+ PYTHON_INCLUDE_DIR = ";".join(self.include_dirs)
+ TORCH_INCLUDE_PATH = ";".join(cpp_extension.include_paths())
+ TORCH_LIBRARY_PATH = ";".join(cpp_extension.library_paths())
+
+ cmake_args = [
+ f"-DCMAKE_BUILD_TYPE={config}",
+ f"-DPYTHON_EXECUTABLE={sys.executable}",
+ f"-DPYTHON_INCLUDE_DIR={PYTHON_INCLUDE_DIR}",
+ f"-DTORCH_INCLUDE_PATH={TORCH_INCLUDE_PATH}",
+ f"-DTORCH_LIBRARY_PATH={TORCH_LIBRARY_PATH}",
+ ]
+
+ build_args = ["--config", config, "--", "-j4"]
+
+ os.chdir(build_temp)
+ self.spawn(["cmake", f"{str(cwd)}"] + cmake_args)
+ if not self.dry_run:
+ self.spawn(["cmake", "--build", "."] + build_args)
+ os.chdir(str(cwd))
+ self.copy(build_temp)
class download_shared():
-
- def __init__(self):
- import urllib
- dir_path = os.path.dirname(os.path.realpath(__file__))
- print(f"setup.py at {dir_path}")
- print("downloading shared libraries")
- op_urls = []
- if sys.version_info >= (3, 8) and sys.version_info < (3, 9):
- op_urls.append(
- "https://torchopt.oss-cn-beijing.aliyuncs.com/torch1_11/adam_op.cpython-38-x86_64-linux-gnu.so"
- )
- elif sys.version_info >= (3, 9) and sys.version_info < (3, 10):
- op_urls.append(
- "https://torchopt.oss-cn-beijing.aliyuncs.com/torch1_11/adam_op.cpython-39-x86_64-linux-gnu.so"
- )
-
- if len(op_urls) == 0:
- import warnings
- warnings.warn("no pre-compiled libraries for you python version")
- return
-
- for url in op_urls:
- data = urllib.request.urlopen(url)
- filename = url.rpartition('/')[-1]
- file_path = os.path.join(dir_path, 'TorchOpt', '_lib', filename)
- with open(file_path, 'wb') as f:
- f.write(data.read())
- print("shared libraries downloaded")
+ def __init__(self):
+ import urllib
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ print(f"setup.py at {dir_path}")
+ print("downloading shared libraries")
+ op_urls = []
+ if sys.version_info >= (3, 8) and sys.version_info < (3, 9):
+ op_urls.append(
+ "https://torchopt.oss-cn-beijing.aliyuncs.com/torch1_11/adam_op.cpython-38-x86_64-linux-gnu.so"
+ )
+ elif sys.version_info >= (3, 9) and sys.version_info < (3, 10):
+ op_urls.append(
+ "https://torchopt.oss-cn-beijing.aliyuncs.com/torch1_11/adam_op.cpython-39-x86_64-linux-gnu.so"
+ )
+
+ if len(op_urls) == 0:
+ import warnings
+ warnings.warn("no pre-compiled libraries for you python version")
+ return
+
+ for url in op_urls:
+ data = urllib.request.urlopen(url)
+ filename = url.rpartition('/')[-1]
+ file_path = os.path.join(dir_path, 'torchopt', '_lib', filename)
+ with open(file_path, 'wb') as f:
+ f.write(data.read())
+ print("shared libraries downloaded")
if 'build_from_source' not in sys.argv:
download_shared()
setup(
- name="TorchOpt",
- version="0.4.1",
- author="TorchOpt Contributors",
- author_email="jieren9806@gmail.com",
- description="A Jax-style optimizer.",
- license="Apache License Version 2.0",
- keywords="meta learning",
- url="https://github.com/metaopt/TorchOpt",
- packages=find_packages(),
- package_data={"": ["_lib/*.so"]},
- include_package_data=True,
- cmdclass={'build_from_source': MyBuild},
- install_requires=[
- 'jax[cpu]',
- 'torch==1.11',
- 'graphviz',
- ],
+ name="torchopt",
+ version="0.4.1",
+ author="TorchOpt Contributors",
+ author_email="jieren9806@gmail.com, xidong.feng.20@ucl.ac.uk, benjaminliu.eecs@gmail.com",
+ description="A Jax-style optimizer.",
+ license="Apache License Version 2.0",
+ keywords="meta learning",
+ url="https://github.com/metaopt/torchopt",
+ packages=find_packages(),
+ package_data={"": ["_lib/*.so"]},
+ include_package_data=True,
+ cmdclass={'build_from_source': MyBuild},
+ install_requires=[
+ 'jax[cpu]',
+ 'torch==1.11',
+ 'graphviz',
+ ],
)
diff --git a/src/adam_op/CMakeLists.txt b/src/adam_op/CMakeLists.txt
index 88991ad0..cea0371b 100644
--- a/src/adam_op/CMakeLists.txt
+++ b/src/adam_op/CMakeLists.txt
@@ -47,4 +47,4 @@ pybind11_add_module(adam_op adam_op.cpp adam_op_impl.cpp adam_op_impl.cu)
target_link_libraries(
adam_op PRIVATE
${TORCH_LIBRARIES}
- )
+)
diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp
index f8cfffce..130e3a27 100644
--- a/src/adam_op/adam_op.cpp
+++ b/src/adam_op/adam_op.cpp
@@ -21,7 +21,7 @@
#include "adam_op/adam_op_impl.cuh"
#include "adam_op/adam_op_impl.h"
-namespace TorchOpt {
+namespace torchopt {
TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
const torch::Tensor& mu,
const torch::Tensor& nu, const float b1,
@@ -110,14 +110,14 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
throw std::runtime_error("Not implemented");
}
}
-} // namespace TorchOpt
+} // namespace torchopt
PYBIND11_MODULE(adam_op, m) {
- m.def("forward_", &TorchOpt::adamForwardInplace);
- m.def("forwardMu", &TorchOpt::adamForwardMu);
- m.def("forwardNu", &TorchOpt::adamForwardNu);
- m.def("forwardUpdates", &TorchOpt::adamForwardUpdates);
- m.def("backwardMu", &TorchOpt::adamBackwardMu);
- m.def("backwardNu", &TorchOpt::adamBackwardNu);
- m.def("backwardUpdates", &TorchOpt::adamBackwardUpdates);
+ m.def("forward_", &torchopt::adamForwardInplace);
+ m.def("forwardMu", &torchopt::adamForwardMu);
+ m.def("forwardNu", &torchopt::adamForwardNu);
+ m.def("forwardUpdates", &torchopt::adamForwardUpdates);
+ m.def("backwardMu", &torchopt::adamBackwardMu);
+ m.def("backwardNu", &torchopt::adamBackwardNu);
+ m.def("backwardUpdates", &torchopt::adamBackwardUpdates);
}
diff --git a/src/adam_op/adam_op_impl.cpp b/src/adam_op/adam_op_impl.cpp
index 48427213..71807d09 100644
--- a/src/adam_op/adam_op_impl.cpp
+++ b/src/adam_op/adam_op_impl.cpp
@@ -13,16 +13,15 @@
// limitations under the License.
// ==============================================================================
-#include "adam_op/adam_op_impl.h"
-
#include
#include
#include
-#include "include/utils.h"
+#include "adam_op/adam_op_impl.h"
+#include "utils.h"
-namespace TorchOpt {
+namespace torchopt {
using std::size_t;
namespace {
template
@@ -307,4 +306,4 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
}));
return TensorArray<2>{std::move(dmu_out), std::move(dnu_out)};
}
-} // namespace TorchOpt
+} // namespace torchopt
diff --git a/src/adam_op/adam_op_impl.cu b/src/adam_op/adam_op_impl.cu
index 0b7b4cea..c32f1ad3 100644
--- a/src/adam_op/adam_op_impl.cu
+++ b/src/adam_op/adam_op_impl.cu
@@ -18,9 +18,9 @@
#include
#include "adam_op/adam_op_impl.cuh"
-#include "include/utils.h"
+#include "utils.h"
-namespace TorchOpt {
+namespace torchopt {
namespace {
template
@@ -330,4 +330,4 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
}));
return TensorArray<2>{std::move(dmu_out), std::move(dnu_out)};
}
-} // namespace TorchOpt
+} // namespace torchopt
diff --git a/tests/unit/high_level/test_high_level_inplace.py b/tests/unit/high_level/test_high_level_inplace.py
index 728b0158..04544ecf 100644
--- a/tests/unit/high_level/test_high_level_inplace.py
+++ b/tests/unit/high_level/test_high_level_inplace.py
@@ -17,182 +17,170 @@
import unittest
import torch
-from torch.nn import functional as F
+import torch.nn.functional as F
from torch.utils import data
from torchvision import models
-from TorchOpt import SGD, Adam, RMSProp
+import torchopt
class HighLevelInplace(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- torch.manual_seed(0)
- cls.model = models.resnet18()
- cls.model_ref = copy.deepcopy(cls.model)
- cls.model_backup = copy.deepcopy(cls.model)
-
- cls.batch_size = 2
- cls.dataset = data.TensorDataset(
- torch.randn(2, 3, 224, 224), torch.randint(0, 1000, (2,))
- )
- cls.loader = data.DataLoader(cls.dataset, cls.batch_size, False)
-
- cls.lr = 1e-3
-
- def setUp(self) -> None:
- torch.manual_seed(0)
- self.model = copy.deepcopy(self.model_backup)
- self.model_ref = copy.deepcopy(self.model_backup)
-
- def test_sgd(self) -> None:
- optim = SGD(self.model.parameters(), self.lr)
- optim_ref = torch.optim.SGD(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- pred = self.model(xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
- optim.zero_grad()
- loss.backward()
- optim.step()
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
-
- with torch.no_grad():
- for p, p_ref in zip(
- self.model.parameters(), self.model_ref.parameters()
- ):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_adam(self) -> None:
- optim = Adam(self.model.parameters(), self.lr)
- optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- pred = self.model(xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
- optim.zero_grad()
- loss.backward()
- optim.step()
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
-
- with torch.no_grad():
- for p, p_ref in zip(
- self.model.parameters(), self.model_ref.parameters()
- ):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_accelerated_adam_cpu(self) -> None:
- self.model
- self.model_ref
- optim = Adam(self.model.parameters(), self.lr, use_accelerated_op=True)
- optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- xs = xs
- ys = ys
- pred = self.model(xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
- optim.zero_grad()
- loss.backward()
- optim.step()
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
-
- with torch.no_grad():
- for p, p_ref in zip(
- self.model.parameters(), self.model_ref.parameters()
- ):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_accelerated_adam_cuda(self) -> None:
- self.model.cuda()
- self.model_ref.cuda()
- optim = Adam(self.model.parameters(), self.lr, use_accelerated_op=True)
- optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- xs = xs.cuda()
- ys = ys.cuda()
- pred = self.model(xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
- optim.zero_grad()
- loss.backward()
- optim.step()
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
-
- with torch.no_grad():
- for p, p_ref in zip(
- self.model.parameters(), self.model_ref.parameters()
- ):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_rmsprop(self) -> None:
- optim = RMSProp(
- self.model.parameters(), self.lr, decay=0.99
- ) # pytorch uses 0.99 as the default value
- optim_ref = torch.optim.RMSprop(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- pred = self.model(xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
- optim.zero_grad()
- loss.backward()
- optim.step()
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
-
- with torch.no_grad():
- for p, p_ref in zip(
- self.model.parameters(), self.model_ref.parameters()
- ):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(
- float(mse), 0, delta=1e-4
- ) # Optax and pytorch have different implementation
- for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
+ @classmethod
+ def setUpClass(cls):
+ torch.manual_seed(0)
+ cls.model = models.resnet18()
+ cls.model_ref = copy.deepcopy(cls.model)
+ cls.model_backup = copy.deepcopy(cls.model)
+
+ cls.batch_size = 2
+ cls.dataset = data.TensorDataset(torch.randn(2, 3, 224, 224), torch.randint(0, 1000, (2,)))
+ cls.loader = data.DataLoader(cls.dataset, cls.batch_size, False)
+
+ cls.lr = 1e-3
+
+ def setUp(self) -> None:
+ torch.manual_seed(0)
+ self.model = copy.deepcopy(self.model_backup)
+ self.model_ref = copy.deepcopy(self.model_backup)
+
+ def test_sgd(self) -> None:
+ optim = torchopt.SGD(self.model.parameters(), self.lr)
+ optim_ref = torch.optim.SGD(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ pred = self.model(xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+ optim.zero_grad()
+ loss.backward()
+ optim.step()
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ with torch.no_grad():
+ for p, p_ref in zip(self.model.parameters(), self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_adam(self) -> None:
+ optim = torchopt.Adam(self.model.parameters(), self.lr)
+ optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ pred = self.model(xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+ optim.zero_grad()
+ loss.backward()
+ optim.step()
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ with torch.no_grad():
+ for p, p_ref in zip(self.model.parameters(), self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_accelerated_adam_cpu(self) -> None:
+ self.model
+ self.model_ref
+ optim = torchopt.Adam(self.model.parameters(), self.lr, use_accelerated_op=True)
+ optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ xs = xs
+ ys = ys
+ pred = self.model(xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+ optim.zero_grad()
+ loss.backward()
+ optim.step()
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ with torch.no_grad():
+ for p, p_ref in zip(self.model.parameters(), self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_accelerated_adam_cuda(self) -> None:
+ self.model.cuda()
+ self.model_ref.cuda()
+ optim = torchopt.Adam(self.model.parameters(), self.lr, use_accelerated_op=True)
+ optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ xs = xs.cuda()
+ ys = ys.cuda()
+ pred = self.model(xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+ optim.zero_grad()
+ loss.backward()
+ optim.step()
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ with torch.no_grad():
+ for p, p_ref in zip(self.model.parameters(), self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_rmsprop(self) -> None:
+ optim = torchopt.RMSProp(
+ self.model.parameters(), self.lr, decay=0.99
+ ) # pytorch uses 0.99 as the default value
+ optim_ref = torch.optim.RMSprop(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ pred = self.model(xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+ optim.zero_grad()
+ loss.backward()
+ optim.step()
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ with torch.no_grad():
+ for p, p_ref in zip(self.model.parameters(), self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(
+ float(mse), 0, delta=1e-4
+ ) # Optax and pytorch have different implementation
+ for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/unit/low_level/test_low_level_inplace.py b/tests/unit/low_level/test_low_level_inplace.py
index e42209c5..c34cd324 100644
--- a/tests/unit/low_level/test_low_level_inplace.py
+++ b/tests/unit/low_level/test_low_level_inplace.py
@@ -18,190 +18,185 @@
import functorch
import torch
-from torch.nn import functional as F
+import torch.nn.functional as F
from torch.utils import data
from torchvision import models
-import TorchOpt
-from TorchOpt import adam, rmsprop, sgd
+import torchopt
class LowLevelInplace(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- torch.manual_seed(0)
- cls.model = models.resnet18()
- cls.model_ref = copy.deepcopy(cls.model)
- cls.model_backup = copy.deepcopy(cls.model)
-
- cls.batch_size = 2
- cls.dataset = data.TensorDataset(
- torch.randn(2, 3, 224, 224), torch.randint(0, 1000, (2,))
- )
- cls.loader = data.DataLoader(cls.dataset, cls.batch_size, False)
-
- cls.lr = 1e-3
-
- def setUp(self) -> None:
- torch.manual_seed(0)
- self.model = copy.deepcopy(self.model_backup)
- self.model_ref = copy.deepcopy(self.model_backup)
-
- def test_sgd(self) -> None:
- fun, params, buffers = functorch.make_functional_with_buffers(self.model)
- optim = sgd(self.lr)
- optim_state = optim.init(params)
- optim_ref = torch.optim.SGD(self.model_ref.parameters(), self.lr)
-
- for xs, ys in self.loader:
- pred = fun(params, buffers, xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
-
- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state)
- params = TorchOpt.apply_updates(params, updates)
-
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
-
- with torch.no_grad():
- for p, p_ref in zip(params, self.model_ref.parameters()):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(buffers, self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_adam(self) -> None:
- fun, params, buffers = functorch.make_functional_with_buffers(self.model)
- optim = adam(self.lr)
- optim_state = optim.init(params)
- optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- pred = fun(params, buffers, xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
-
- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state)
- params = TorchOpt.apply_updates(params, updates)
-
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
- with torch.no_grad():
- for p, p_ref in zip(params, self.model_ref.parameters()):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(buffers, self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_accelerated_adam_cpu(self) -> None:
- self.model
- self.model_ref
- fun, params, buffers = functorch.make_functional_with_buffers(self.model)
- optim = adam(self.lr, use_accelerated_op=True)
- optim_state = optim.init(params)
- optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- xs = xs
- ys = ys
- pred = fun(params, buffers, xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
-
- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state)
- params = TorchOpt.apply_updates(params, updates)
-
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
- with torch.no_grad():
- for p, p_ref in zip(params, self.model_ref.parameters()):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(buffers, self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_accelerated_adam_cuda(self) -> None:
- self.model.cuda()
- self.model_ref.cuda()
- fun, params, buffers = functorch.make_functional_with_buffers(self.model)
- optim = adam(self.lr, use_accelerated_op=True)
- optim_state = optim.init(params)
- optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- xs = xs.cuda()
- ys = ys.cuda()
- pred = fun(params, buffers, xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
-
- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state)
- params = TorchOpt.apply_updates(params, updates)
-
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
- with torch.no_grad():
- for p, p_ref in zip(params, self.model_ref.parameters()):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(buffers, self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
-
- def test_rmsprop(self) -> None:
- fun, params, buffers = functorch.make_functional_with_buffers(self.model)
- optim = rmsprop(
- self.lr, decay=0.99
- ) # pytorch uses 0.99 as the default value
- optim_state = optim.init(params)
- optim_ref = torch.optim.RMSprop(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- pred = fun(params, buffers, xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
-
- grad = torch.autograd.grad(loss, params)
- updates, optim_state = optim.update(grad, optim_state)
- params = TorchOpt.apply_updates(params, updates)
-
- optim_ref.zero_grad()
- loss_ref.backward()
- optim_ref.step()
- with torch.no_grad():
- for p, p_ref in zip(params, self.model_ref.parameters()):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(
- float(mse), 0, delta=1e-4
- ) # Optax and pytorch have different implementation
- for b, b_ref in zip(buffers, self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
+ @classmethod
+ def setUpClass(cls):
+ torch.manual_seed(0)
+ cls.model = models.resnet18()
+ cls.model_ref = copy.deepcopy(cls.model)
+ cls.model_backup = copy.deepcopy(cls.model)
+
+ cls.batch_size = 2
+ cls.dataset = data.TensorDataset(torch.randn(2, 3, 224, 224), torch.randint(0, 1000, (2,)))
+ cls.loader = data.DataLoader(cls.dataset, cls.batch_size, False)
+
+ cls.lr = 1e-3
+
+ def setUp(self) -> None:
+ torch.manual_seed(0)
+ self.model = copy.deepcopy(self.model_backup)
+ self.model_ref = copy.deepcopy(self.model_backup)
+
+ def test_sgd(self) -> None:
+ fun, params, buffers = functorch.make_functional_with_buffers(self.model)
+ optim = torchopt.sgd(self.lr)
+ optim_state = optim.init(params)
+ optim_ref = torch.optim.SGD(self.model_ref.parameters(), self.lr)
+
+ for xs, ys in self.loader:
+ pred = fun(params, buffers, xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+
+ grad = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grad, optim_state)
+ params = torchopt.apply_updates(params, updates)
+
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+
+ with torch.no_grad():
+ for p, p_ref in zip(params, self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(buffers, self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_adam(self) -> None:
+ fun, params, buffers = functorch.make_functional_with_buffers(self.model)
+ optim = torchopt.adam(self.lr)
+ optim_state = optim.init(params)
+ optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ pred = fun(params, buffers, xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+
+ grad = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grad, optim_state)
+ params = torchopt.apply_updates(params, updates)
+
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+ with torch.no_grad():
+ for p, p_ref in zip(params, self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(buffers, self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_accelerated_adam_cpu(self) -> None:
+ self.model
+ self.model_ref
+ fun, params, buffers = functorch.make_functional_with_buffers(self.model)
+ optim = torchopt.adam(self.lr, use_accelerated_op=True)
+ optim_state = optim.init(params)
+ optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ xs = xs
+ ys = ys
+ pred = fun(params, buffers, xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+
+ grad = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grad, optim_state)
+ params = torchopt.apply_updates(params, updates)
+
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+ with torch.no_grad():
+ for p, p_ref in zip(params, self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(buffers, self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_accelerated_adam_cuda(self) -> None:
+ self.model.cuda()
+ self.model_ref.cuda()
+ fun, params, buffers = functorch.make_functional_with_buffers(self.model)
+ optim = torchopt.adam(self.lr, use_accelerated_op=True)
+ optim_state = optim.init(params)
+ optim_ref = torch.optim.Adam(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ xs = xs.cuda()
+ ys = ys.cuda()
+ pred = fun(params, buffers, xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+
+ grad = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grad, optim_state)
+ params = torchopt.apply_updates(params, updates)
+
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+ with torch.no_grad():
+ for p, p_ref in zip(params, self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(buffers, self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
+
+ def test_rmsprop(self) -> None:
+ fun, params, buffers = functorch.make_functional_with_buffers(self.model)
+ optim = torchopt.rmsprop(self.lr, decay=0.99) # pytorch uses 0.99 as the default value
+ optim_state = optim.init(params)
+ optim_ref = torch.optim.RMSprop(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ pred = fun(params, buffers, xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+
+ grad = torch.autograd.grad(loss, params)
+ updates, optim_state = optim.update(grad, optim_state)
+ params = torchopt.apply_updates(params, updates)
+
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ optim_ref.step()
+ with torch.no_grad():
+ for p, p_ref in zip(params, self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(
+ float(mse), 0, delta=1e-4
+ ) # Optax and pytorch have different implementation
+ for b, b_ref in zip(buffers, self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/unit/test_clip.py b/tests/unit/test_clip.py
index c129db6e..5967c9f4 100644
--- a/tests/unit/test_clip.py
+++ b/tests/unit/test_clip.py
@@ -17,69 +17,64 @@
import unittest
import torch
-from torch.nn import functional as F
+import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils import data
from torchvision import models
-import TorchOpt
-from TorchOpt import Optimizer, sgd
+import torchopt
class HighLevelInplace(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- torch.manual_seed(0)
- cls.model = models.resnet18()
- cls.model_backup = copy.deepcopy(cls.model)
- cls.model_ref = copy.deepcopy(cls.model)
+ @classmethod
+ def setUpClass(cls):
+ torch.manual_seed(0)
+ cls.model = models.resnet18()
+ cls.model_backup = copy.deepcopy(cls.model)
+ cls.model_ref = copy.deepcopy(cls.model)
- cls.batch_size = 2
- cls.dataset = data.TensorDataset(
- torch.randn(2, 3, 224, 224), torch.randint(0, 1000, (2,))
- )
- cls.loader = data.DataLoader(cls.dataset, cls.batch_size, False)
+ cls.batch_size = 2
+ cls.dataset = data.TensorDataset(torch.randn(2, 3, 224, 224), torch.randint(0, 1000, (2,)))
+ cls.loader = data.DataLoader(cls.dataset, cls.batch_size, False)
- cls.lr = 1e0
- cls.max_norm = 10.
+ cls.lr = 1e0
+ cls.max_norm = 10.
- def setUp(self) -> None:
- torch.manual_seed(0)
- self.model = copy.deepcopy(self.model_backup)
- self.model_ref = copy.deepcopy(self.model_backup)
+ def setUp(self) -> None:
+ torch.manual_seed(0)
+ self.model = copy.deepcopy(self.model_backup)
+ self.model_ref = copy.deepcopy(self.model_backup)
- def test_sgd(self) -> None:
- chain = TorchOpt.combine.chain(
- TorchOpt.clip.clip_grad_norm(max_norm=self.max_norm), sgd(lr=self.lr)
- )
- optim = Optimizer(self.model.parameters(), chain)
- optim_ref = torch.optim.SGD(self.model_ref.parameters(), self.lr)
- for xs, ys in self.loader:
- pred = self.model(xs)
- pred_ref = self.model_ref(xs)
- loss = F.cross_entropy(pred, ys)
- loss_ref = F.cross_entropy(pred_ref, ys)
- optim.zero_grad()
- loss.backward()
- optim.step()
- optim_ref.zero_grad()
- loss_ref.backward()
- clip_grad_norm_(self.model_ref.parameters(), max_norm=self.max_norm)
- optim_ref.step()
+ def test_sgd(self) -> None:
+ chain = torchopt.combine.chain(
+ torchopt.clip.clip_grad_norm(max_norm=self.max_norm), torchopt.sgd(lr=self.lr)
+ )
+ optim = torchopt.Optimizer(self.model.parameters(), chain)
+ optim_ref = torch.optim.SGD(self.model_ref.parameters(), self.lr)
+ for xs, ys in self.loader:
+ pred = self.model(xs)
+ pred_ref = self.model_ref(xs)
+ loss = F.cross_entropy(pred, ys)
+ loss_ref = F.cross_entropy(pred_ref, ys)
+ optim.zero_grad()
+ loss.backward()
+ optim.step()
+ optim_ref.zero_grad()
+ loss_ref.backward()
+ clip_grad_norm_(self.model_ref.parameters(), max_norm=self.max_norm)
+ optim_ref.step()
- with torch.no_grad():
- for p, p_ref in zip(
- self.model.parameters(), self.model_ref.parameters()
- ):
- mse = F.mse_loss(p, p_ref)
- self.assertAlmostEqual(float(mse), 0)
- for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
- b = b.float() if not b.is_floating_point() else b
- b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
- mse = F.mse_loss(b, b_ref)
- self.assertAlmostEqual(float(mse), 0)
+ with torch.no_grad():
+ for p, p_ref in zip(self.model.parameters(), self.model_ref.parameters()):
+ mse = F.mse_loss(p, p_ref)
+ self.assertAlmostEqual(float(mse), 0)
+ for b, b_ref in zip(self.model.buffers(), self.model_ref.buffers()):
+ b = b.float() if not b.is_floating_point() else b
+ b_ref = b_ref.float() if not b_ref.is_floating_point() else b_ref
+ mse = F.mse_loss(b, b_ref)
+ self.assertAlmostEqual(float(mse), 0)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/tests/unit/test_schedule.py b/tests/unit/test_schedule.py
index 1e8f2831..66950050 100644
--- a/tests/unit/test_schedule.py
+++ b/tests/unit/test_schedule.py
@@ -15,35 +15,35 @@
import unittest
-import TorchOpt
+import torchopt
class TestSchedule(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- cls.init_value = 1.
- cls.end_value = 0.
- cls.gap_value = cls.init_value - cls.end_value
- cls.transition_steps = 10
- cls.transition_begin = 1
-
- def setUp(self) -> None:
- pass
-
- def test_linear(self) -> None:
- schedule = TorchOpt.schedule.linear_schedule(
- init_value=self.init_value,
- end_value=self.end_value,
- transition_steps=self.transition_steps,
- transition_begin=self.transition_begin
- )
- for i in range(self.transition_begin, self.transition_steps):
- lr = schedule(i)
- lr_gt = self.init_value - self.gap_value * \
- (i - self.transition_begin) / self.transition_steps
- self.assertEqual(lr, lr_gt)
+ @classmethod
+ def setUpClass(cls):
+ cls.init_value = 1.
+ cls.end_value = 0.
+ cls.gap_value = cls.init_value - cls.end_value
+ cls.transition_steps = 10
+ cls.transition_begin = 1
+
+ def setUp(self) -> None:
+ pass
+
+ def test_linear(self) -> None:
+ schedule = torchopt.schedule.linear_schedule(
+ init_value=self.init_value,
+ end_value=self.end_value,
+ transition_steps=self.transition_steps,
+ transition_begin=self.transition_begin
+ )
+ for i in range(self.transition_begin, self.transition_steps):
+ lr = schedule(i)
+ lr_gt = self.init_value - self.gap_value * \
+ (i - self.transition_begin) / self.transition_steps
+ self.assertEqual(lr, lr_gt)
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
diff --git a/torchopt/__init__.py b/torchopt/__init__.py
new file mode 100644
index 00000000..6672c724
--- /dev/null
+++ b/torchopt/__init__.py
@@ -0,0 +1,64 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TorchOpt: a high-performance optimizer library built upon PyTorch."""
+
+from torchopt._src import (
+ accelerated_op_available,
+ clip,
+ combine,
+ hook,
+ schedule,
+ visual,
+)
+from torchopt._src.alias import adam, rmsprop, sgd
+from torchopt._src.optimizer import SGD, Adam, Optimizer, RMSProp, meta
+from torchopt._src.optimizer.meta import (
+ MetaAdam,
+ MetaOptimizer,
+ MetaRMSProp,
+ MetaSGD,
+)
+from torchopt._src.update import apply_updates
+from torchopt._src.utils import (
+ extract_state_dict,
+ recover_state_dict,
+ stop_gradient,
+)
+
+__version__ = "0.4.1"
+
+__all__ = [
+ "accelerated_op_available",
+ "clip",
+ "combine",
+ "hook",
+ "schedule",
+ "visual",
+ "adam",
+ "rmsprop",
+ "sgd",
+ "Optimizer",
+ "SGD",
+ "Adam",
+ "RMSProp",
+ "MetaOptimizer",
+ "MetaSGD",
+ "MetaAdam",
+ "MetaRMSProp",
+ "apply_updates",
+ "extract_state_dict",
+ "recover_state_dict",
+ "stop_gradient",
+]
diff --git a/TorchOpt/_lib/__init__.py b/torchopt/_lib/__init__.py
similarity index 100%
rename from TorchOpt/_lib/__init__.py
rename to torchopt/_lib/__init__.py
diff --git a/torchopt/_lib/adam_op.py b/torchopt/_lib/adam_op.py
new file mode 100644
index 00000000..ca10e621
--- /dev/null
+++ b/torchopt/_lib/adam_op.py
@@ -0,0 +1,57 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================\
+
+from typing import Tuple
+
+import torch
+
+
+def forward_(
+ updates: torch.Tensor, mu: torch.Tensor, nu: torch.Tensor, b1: float, b2: float, eps: float,
+ eps_root: float, count: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ...
+
+
+def forwardMu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor:
+ ...
+
+
+def forwardNu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor:
+ ...
+
+
+def forwardUpdates(
+ new_mu: torch.Tensor, new_nu: torch.Tensor, b1: float, b2: float, eps: float, eps_root: float,
+ count: int
+) -> torch.Tensor:
+ ...
+
+
+def backwardMu(dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor,
+ b1: float) -> Tuple[torch.Tensor, torch.Tensor]:
+ ...
+
+
+def backwardNu(dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor,
+ b2: float) -> Tuple[torch.Tensor, torch.Tensor]:
+ ...
+
+
+def backwardUpdates(
+ dupdates: torch.Tensor, updates: torch.Tensor, new_mu: torch.Tensor, new_nu: torch.Tensor,
+ b1: float, b2: float, count: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ ...
diff --git a/TorchOpt/_src/__init__.py b/torchopt/_src/__init__.py
similarity index 91%
rename from TorchOpt/_src/__init__.py
rename to torchopt/_src/__init__.py
index 522a892f..75b3cf8d 100644
--- a/TorchOpt/_src/__init__.py
+++ b/torchopt/_src/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.
# ==============================================================================
-from TorchOpt._src.accelerated_op import accelerated_op_available
+from torchopt._src.accelerated_op import accelerated_op_available
diff --git a/TorchOpt/_src/accelerated_op/__init__.py b/torchopt/_src/accelerated_op/__init__.py
similarity index 61%
rename from TorchOpt/_src/accelerated_op/__init__.py
rename to torchopt/_src/accelerated_op/__init__.py
index d6fa1792..ab494d23 100644
--- a/TorchOpt/_src/accelerated_op/__init__.py
+++ b/torchopt/_src/accelerated_op/__init__.py
@@ -13,20 +13,20 @@
# limitations under the License.
# ==============================================================================
-from TorchOpt._src.accelerated_op.adam_op import AdamOp
+from torchopt._src.accelerated_op.adam_op import AdamOp
def accelerated_op_available(devices=None):
- import torch
- op = AdamOp()
- if devices is None:
- devices = [torch.device("cuda"), torch.device("cpu")]
- elif isinstance(devices, torch.device):
- devices = [devices]
- try:
- for device in devices:
- updates = torch.tensor(1., device=device)
- op(updates, updates, updates, 1)
- return True
- except:
- return False
+ import torch
+ op = AdamOp()
+ if devices is None:
+ devices = [torch.device("cuda"), torch.device("cpu")]
+ elif isinstance(devices, torch.device):
+ devices = [devices]
+ try:
+ for device in devices:
+ updates = torch.tensor(1., device=device)
+ op(updates, updates, updates, 1)
+ return True
+ except BaseException:
+ return False
diff --git a/TorchOpt/_src/accelerated_op/adam_op/__init__.py b/torchopt/_src/accelerated_op/adam_op/__init__.py
similarity index 91%
rename from TorchOpt/_src/accelerated_op/adam_op/__init__.py
rename to torchopt/_src/accelerated_op/adam_op/__init__.py
index 95a47453..d1203e92 100644
--- a/TorchOpt/_src/accelerated_op/adam_op/__init__.py
+++ b/torchopt/_src/accelerated_op/adam_op/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.
# ==============================================================================
-from TorchOpt._src.accelerated_op.adam_op.AdamOp import AdamOp
+from torchopt._src.accelerated_op.adam_op.adam_op import AdamOp
diff --git a/torchopt/_src/accelerated_op/adam_op/adam_op.py b/torchopt/_src/accelerated_op/adam_op/adam_op.py
new file mode 100644
index 00000000..94098520
--- /dev/null
+++ b/torchopt/_src/accelerated_op/adam_op/adam_op.py
@@ -0,0 +1,116 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any
+
+import torch
+
+from torchopt._lib import adam_op
+
+
+class AdamOp(object):
+
+ class MuOp(torch.autograd.Function):
+
+ @staticmethod
+ def jvp(ctx: Any, *grad_inputs: Any) -> Any:
+ pass
+
+ @staticmethod
+ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+ updates, mu, b1 = args
+ new_mu = adam_op.forwardMu(updates, mu, b1)
+ ctx.save_for_backward(updates, mu)
+ ctx.b1 = b1
+ return new_mu
+
+ @staticmethod
+ def backward(ctx: Any, *args: Any) -> Any:
+ dmu = args[0]
+ updates, mu = ctx.saved_tensors
+ b1 = ctx.b1
+ result = adam_op.backwardMu(dmu, updates, mu, b1)
+ return result[0], result[1], None
+
+ class NuOp(torch.autograd.Function):
+
+ @staticmethod
+ def jvp(ctx: Any, *grad_inputs: Any) -> Any:
+ pass
+
+ @staticmethod
+ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+ updates, nu, b2 = args
+ new_nu = adam_op.forwardNu(updates, nu, b2)
+ ctx.save_for_backward(updates, nu)
+ ctx.b2 = b2
+ return new_nu
+
+ @staticmethod
+ def backward(ctx: Any, *args: Any) -> Any:
+ dnu = args[0]
+ updates, nu = ctx.saved_tensors
+ b2 = ctx.b2
+ result = adam_op.backwardNu(dnu, updates, nu, b2)
+ return result[0], result[1], None
+
+ class UpdatesOp(torch.autograd.Function):
+
+ @staticmethod
+ def jvp(ctx: Any, *grad_inputs: Any) -> Any:
+ pass
+
+ @staticmethod
+ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
+ new_mu, new_nu, (b1, b2, eps, eps_root, count) = args
+ new_updates = adam_op.forwardUpdates(new_mu, new_nu, b1, b2, eps, eps_root, count)
+ ctx.save_for_backward(new_updates, new_mu, new_nu)
+ ctx.others = (b1, b2, eps, eps_root, count)
+ return new_updates
+
+ @staticmethod
+ def backward(ctx: Any, *args: Any) -> Any:
+ dupdates = args[0]
+ updates, new_mu, new_nu = ctx.saved_tensors
+ b1, b2, eps, eps_root, count = ctx.others
+ result = adam_op.backwardUpdates(dupdates, updates, new_mu, new_nu, b1, b2, count)
+ return result[0], result[1], None
+
+ def __init__(self, b1=0.9, b2=0.999, eps=1e-8, eps_root=0., inplace=True):
+ self.b1 = b1
+ self.b2 = b2
+ self.eps = eps
+ self.eps_root = eps_root
+ self.inplace = inplace
+
+ def __call__(self, mu, nu, updates, count):
+ if updates is None:
+ return mu, nu, None
+ if updates.is_cuda:
+ current_device = torch.cuda.current_device()
+ torch.cuda.set_device(updates.device)
+ if self.inplace:
+ new_updates, new_mu, new_nu = adam_op.forward_(
+ updates, mu, nu, self.b1, self.b2, self.eps, self.eps_root, count
+ )
+ else:
+ new_mu = self.MuOp.apply(updates, mu, self.b1)
+ new_nu = self.NuOp.apply(updates, nu, self.b2)
+ new_updates = self.UpdatesOp.apply(
+ new_mu, new_nu, (self.b1, self.b2, self.eps, self.eps_root, count)
+ )
+ if updates.is_cuda:
+ torch.cuda.set_device(current_device)
+ return new_mu, new_nu, new_updates
diff --git a/torchopt/_src/alias.py b/torchopt/_src/alias.py
new file mode 100644
index 00000000..a29adca1
--- /dev/null
+++ b/torchopt/_src/alias.py
@@ -0,0 +1,205 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py
+# ==============================================================================
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Optional
+
+import jax
+
+from torchopt._src import base, combine, transform
+from torchopt._src.typing import ScalarOrSchedule
+
+
+def _scale_by_lr(lr: ScalarOrSchedule, flip_sign=True):
+ m = -1 if flip_sign else 1
+ if callable(lr):
+
+ def schedule_wrapper(count):
+
+ def f(scaled_lr):
+ return m * scaled_lr
+
+ return jax.tree_map(f, lr(count)) # type: ignore
+
+ return transform.scale_by_schedule(schedule_wrapper)
+ return transform.scale(m * lr)
+
+
+def adam(
+ lr: ScalarOrSchedule,
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ moment_requires_grad: bool = False,
+ use_accelerated_op: bool = False
+) -> base.GradientTransformation:
+ """The classic Adam optimizer.
+
+ Adam is an SGD variant with learning rate adaptation. The `lr`
+ used for each weight is computed from estimates of first- and second-order
+ moments of the gradients (using suitable exponential moving averages).
+
+ References:
+ Kingma et al, 2014: https://arxiv.org/abs/1412.6980
+
+ Args:
+ lr:
+ This is a fixed global scaling factor.
+ b1:
+ The exponential decay rate to track the first moment of past gradients.
+ b2:
+ The exponential decay rate to track the second moment of past gradients.
+ eps:
+ A small constant applied to denominator outside of the square root
+ (as in the Adam paper) to avoid dividing by zero when rescaling.
+ eps_root: (default `0`)
+ A small constant applied to denominator inside the square root (as
+ in RMSProp), to avoid dividing by zero when rescaling. This is needed
+ for example when computing (meta-)gradients through Adam.
+ moment_requires_grad: (default `False`)
+ If True the momentums will be created with flag `requires_grad=True`,
+ this flag is often used in Meta Learning algorithms.
+ use_accelerated_op: (default `False`)
+ If True use our implemented fused operator.
+
+ Returns:
+ The corresponding `GradientTransformation` instance.
+ """
+
+ adam_inst = transform.scale_by_accelerated_adam if use_accelerated_op else transform.scale_by_adam
+ return combine.chain(
+ adam_inst(
+ b1=b1, b2=b2, eps=eps, eps_root=eps_root, moment_requires_grad=moment_requires_grad
+ ),
+ _scale_by_lr(lr),
+ )
+
+
+def sgd(
+ lr: ScalarOrSchedule,
+ momentum: Optional[float] = None,
+ nesterov: bool = False,
+ moment_requires_grad: bool = False,
+) -> base.GradientTransformation:
+ """A canonical Stochastic Gradient Descent optimiser.
+
+ This implements stochastic gradient descent. It also includes support for
+ momentum, and nesterov acceleration, as these are standard practice when
+ using stochastic gradient descent to train deep neural networks.
+
+ References:
+ Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf
+
+ Args:
+ lr:
+ This is a fixed global scaling factor.
+ momentum: (default `None`)
+ The `decay` rate used by the momentum term, when it is set to `None`,
+ then momentum is not used at all.
+ nesterov (default `False`):
+ Whether nesterov momentum is used.
+ moment_requires_grad: (default `False`)
+ If True the momentums will be created with flag `requires_grad=True`,
+ this flag is often used in Meta-Learning algorithms.
+
+ Returns:
+ A `GradientTransformation` instance.
+ """
+
+ return combine.chain(
+ (
+ transform.trace(
+ decay=momentum, nesterov=nesterov, moment_requires_grad=moment_requires_grad
+ ) if momentum is not None else base.identity()
+ ), _scale_by_lr(lr)
+ )
+
+
+def rmsprop(
+ lr: ScalarOrSchedule,
+ decay: float = 0.9,
+ eps: float = 1e-8,
+ initial_scale: float = 0.,
+ centered: bool = False,
+ momentum: Optional[float] = None,
+ nesterov: bool = False
+) -> base.GradientTransformation:
+ """A flexible RMSProp optimizer.
+ RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
+ used for each weight is scaled by a suitable estimate of the magnitude of the
+ gradients on previous steps. Several variants of RMSProp can be found
+ in the literature. This alias provides an easy to configure RMSProp
+ optimizer that can be used to switch between several of these variants.
+
+ References:
+ Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
+ Graves, 2013: https://arxiv.org/abs/1308.0850
+
+ Args:
+ learning_rate:
+ This is a fixed global scaling factor.
+ decay:
+ The decay used to track the magnitude of previous gradients.
+ eps:
+ A small numerical constant to avoid dividing by zero when rescaling.
+ initial_scale: (default `0.`)
+ Initialization of accumulators tracking the magnitude of previous
+ updates. PyTorch uses `0`, TF1 uses `1`. When reproducing results
+ from a paper, verify the value used by the authors.
+ centered: (default `False`)
+ Whether the second moment or the variance of the past gradients is
+ used to rescale the latest gradients.
+ momentum: (default `None`)
+ The `decay` rate used by the momentum term, when it is set to `None`,
+ then momentum is not used at all.
+ nesterov (default `False`):
+ Whether nesterov momentum is used.
+
+ Returns:
+ The corresponding `GradientTransformation` instance.
+ """
+
+ if centered:
+ return combine.chain(
+ transform.scale_by_stddev(decay=decay, eps=eps, initial_scale=initial_scale),
+ _scale_by_lr(lr), (
+ transform.trace(decay=momentum, nesterov=nesterov)
+ if momentum is not None else base.identity()
+ )
+ )
+ return combine.chain(
+ transform.scale_by_rms(decay=decay, eps=eps, initial_scale=initial_scale), _scale_by_lr(lr),
+ (
+ transform.trace(decay=momentum, nesterov=nesterov)
+ if momentum is not None else base.identity()
+ )
+ )
diff --git a/torchopt/_src/base.py b/torchopt/_src/base.py
new file mode 100644
index 00000000..03cd0b97
--- /dev/null
+++ b/torchopt/_src/base.py
@@ -0,0 +1,151 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/deepmind/optax/blob/master/optax/_src/base.py
+# ==============================================================================
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable, NamedTuple, Tuple
+
+import typing_extensions
+
+from torchopt._src import typing
+
+OptState = typing.TensorTree # States are arbitrary nests of `torch.Tensor`.
+# Parameters are arbitrary nests of `torch.Tensor`.
+Params = typing.TensorTree
+Updates = Params # Gradient updates are of the same type as parameters.
+
+Schedule = Callable[[typing.Numeric], typing.Numeric]
+
+
+class EmptyState(NamedTuple):
+ """An empty state for the simplest stateless transformations."""
+
+
+class TransformInitFn(typing_extensions.Protocol):
+ """A callable type for the `init` step of a `GradientTransformation`.
+
+ The `init` step takes a tree of `params` and uses these to construct an
+ arbitrary structured initial `state` for the gradient transformation. This
+ may hold statistics of the past updates or any other non static information.
+ """
+
+ def __call__(self, params: Params) -> OptState:
+ """The `init` function.
+
+ Args:
+ params:
+ The initial value of the parameters.
+
+ Returns:
+ The initial state of the gradient transformation.
+ """
+ ...
+
+
+class TransformUpdateFn(typing_extensions.Protocol):
+ """A callable type for the `update` step of a `GradientTransformation`.
+
+ The `update` step takes a tree of candidate parameter `updates` (e.g. their
+ gradient with respect to some loss), an arbitrary structured `state`, and the
+ current `params` of the model being optimized. The `params` argument is
+ optional, it must however be provided when using transformations that require
+ access to the current values of the parameters.
+ """
+
+ def __call__(self,
+ updates: Updates,
+ state: OptState,
+ inplace: bool = True) -> Tuple[Updates, OptState]:
+ """The `update` function.
+
+ Args:
+ updates:
+ A tree of candidate updates.
+ state:
+ The state of the gradient transformation.
+ inplace: (optional)
+ If true, modify updates and state using inplace operations.
+
+ Returns:
+ The transformed updates, and the updated state.
+ """
+ ...
+
+
+class GradientTransformation(NamedTuple):
+ """A pair of pure functions implementing a gradient transformation.
+
+ TorchOpt optimizers are all implemented as _gradient transformations_ like
+ Optax. A gradient transformation is defined to be a pair of pure functions,
+ which are combined together in a `NamedTuple` so that they can be referred
+ to by name.
+
+ Since gradient transformations do not contain any internal state, all stateful
+ optimizer properties (such as the current step count when using optimizer
+ schedules, or momentum values) are passed through gradient transformations by
+ using the optimizer _state_ pytree. Each time a gradient transformation is
+ applied, the state is computed and returned, ready to be passed to the next
+ call to the gradient transformation.
+
+ Attributes:
+ init:
+ A pure function which, when called with an example instance of the
+ parameters whose gradients will be transformed, returns a pytree
+ containing the initial value for the optimizer state.
+ update:
+ A pure function which takes as input a pytree of updates (with the
+ same tree structure as the original params pytree passed to init),
+ the previous optimizer state (which may have been initialized using
+ the init function), and optionally the inplace flag. The update
+ function then returns the computed gradient updates, and a updates
+ optimizer state. If the inplace flag is true, the output results are
+ the same instance as the input.
+ """
+
+ init: TransformInitFn
+ update: TransformUpdateFn
+
+
+def identity() -> GradientTransformation:
+ """Stateless identity transformation that leaves input gradients untouched.
+
+ This function passes through the *gradient updates* unchanged.
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(_):
+ return EmptyState()
+
+ def update_fn(updates, state, inplace=False):
+ return updates, state
+
+ return GradientTransformation(init_fn, update_fn)
diff --git a/torchopt/_src/clip.py b/torchopt/_src/clip.py
new file mode 100644
index 00000000..c5da0812
--- /dev/null
+++ b/torchopt/_src/clip.py
@@ -0,0 +1,88 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py
+# ==============================================================================
+
+import jax
+import torch
+from torch._six import inf
+
+from torchopt._src import base
+
+ClipState = base.EmptyState
+
+
+def clip_grad_norm(
+ max_norm: float,
+ norm_type: float = 2.,
+ error_if_nonfinite: bool = False
+) -> base.GradientTransformation:
+ """Clips gradient norm of an iterable of parameters.
+
+ Args:
+ max_delta:
+ The maximum absolute value for each element in the update.
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(params):
+ del params
+ return ClipState()
+
+ def update_fn(updates, state, inplace=True):
+ available_updates = []
+ for g in updates:
+ if g is not None:
+ available_updates.append(g)
+ if len(available_updates) == 0:
+ return torch.tensor(0.)
+ device = available_updates[0].device
+ with torch.no_grad():
+ if norm_type == inf:
+ norms = [p.abs().max().to(device) for p in available_updates]
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+ else:
+ total_norm = torch.norm(
+ torch.stack([torch.norm(p, norm_type).to(device) for p in available_updates]),
+ norm_type
+ )
+ if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+ raise RuntimeError(
+ f'The total norm of order {norm_type} for gradients from '
+ '`parameters` is non-finite, so it cannot be clipped. To disable '
+ 'this error and scale the gradients by the non-finite norm anyway, '
+ 'set `error_if_nonfinite=False`'
+ )
+ clip_coef = max_norm / (float(total_norm) + 1e-6)
+ # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+ # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+ # when the gradients do not reside in CPU memory.
+ clip_coef_clamped = min(clip_coef, 1.)
+ if inplace:
+
+ def f(g):
+ return g.mul_(clip_coef_clamped) if g is not None else None
+ else:
+
+ def f(g):
+ return g.mul(clip_coef_clamped) if g is not None else None
+
+ new_updates = jax.tree_map(f, updates)
+ return new_updates, state
+
+ return base.GradientTransformation(init_fn, update_fn)
diff --git a/TorchOpt/_src/combine.py b/torchopt/_src/combine.py
similarity index 58%
rename from TorchOpt/_src/combine.py
rename to torchopt/_src/combine.py
index 396a2bc4..081421c9 100644
--- a/TorchOpt/_src/combine.py
+++ b/torchopt/_src/combine.py
@@ -30,39 +30,40 @@
# limitations under the License.
# ==============================================================================
-from TorchOpt._src import base
+from torchopt._src import base
def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
- """Applies a list of chainable update transformations.
+ """Applies a list of chainable update transformations.
- Given a sequence of chainable transforms, `chain` returns an `init_fn`
- that constructs a `state` by concatenating the states of the individual
- transforms, and returns an `update_fn` which chains the update transformations
- feeding the appropriate state to each.
+ Given a sequence of chainable transforms, `chain` returns an `init_fn`
+ that constructs a `state` by concatenating the states of the individual
+ transforms, and returns an `update_fn` which chains the update transformations
+ feeding the appropriate state to each.
- Args:
- *args: a sequence of chainable (init_fn, update_fn) tuples.
+ Args:
+ *args:
+ A sequence of chainable (init_fn, update_fn) tuples.
- Returns:
- A single (init_fn, update_fn) tuple.
- """
+ Returns:
+ A single (init_fn, update_fn) tuple.
+ """
- init_fns, update_fns = zip(*args)
+ init_fns, update_fns = zip(*args)
- def init_fn(params):
- return tuple(fn(params) for fn in init_fns)
+ def init_fn(params):
+ return tuple(fn(params) for fn in init_fns)
- def update_fn(updates, state, inplace=True):
- if len(update_fns) != len(state):
- raise ValueError(
- 'The number of updates and states has to be the same in '
- 'chain! Make sure you have called init first!'
- )
- new_state = []
- for s, fn in zip(state, update_fns):
- updates, new_s = fn(updates, s, inplace)
- new_state.append(new_s)
- return updates, tuple(new_state)
+ def update_fn(updates, state, inplace=True):
+ if len(update_fns) != len(state):
+ raise ValueError(
+ 'The number of updates and states has to be the same in '
+ 'chain! Make sure you have called init first!'
+ )
+ new_state = []
+ for s, fn in zip(state, update_fns):
+ updates, new_s = fn(updates, s, inplace)
+ new_state.append(new_s)
+ return updates, tuple(new_state)
- return base.GradientTransformation(init_fn, update_fn)
+ return base.GradientTransformation(init_fn, update_fn)
diff --git a/TorchOpt/_src/hook.py b/torchopt/_src/hook.py
similarity index 56%
rename from TorchOpt/_src/hook.py
rename to torchopt/_src/hook.py
index 93ca980b..77ae1bd0 100644
--- a/TorchOpt/_src/hook.py
+++ b/torchopt/_src/hook.py
@@ -16,31 +16,31 @@
import jax
import torch
-from TorchOpt._src.base import EmptyState, GradientTransformation
+from torchopt._src.base import EmptyState, GradientTransformation
def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
- return torch.where(torch.isnan(g), torch.zeros_like(g), g)
+ return torch.where(torch.isnan(g), torch.zeros_like(g), g)
def register_hook(hook) -> GradientTransformation:
- """Stateless identity transformation that leaves input gradients untouched.
+ """Stateless identity transformation that leaves input gradients untouched.
- This function passes through the *gradient updates* unchanged.
+ This function passes through the *gradient updates* unchanged.
- Returns:
- An (init_fn, update_fn) tuple.
- """
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
- def init_fn(_):
- return EmptyState()
+ def init_fn(_):
+ return EmptyState()
- def update_fn(updates, state, inplace=False):
+ def update_fn(updates, state, inplace=False):
- def f(g):
- return g.register_hook(hook) if g is not None else None
+ def f(g):
+ return g.register_hook(hook) if g is not None else None
- jax.tree_map(f, updates)
- return updates, state
+ jax.tree_map(f, updates)
+ return updates, state
- return GradientTransformation(init_fn, update_fn)
+ return GradientTransformation(init_fn, update_fn)
diff --git a/torchopt/_src/optimizer/__init__.py b/torchopt/_src/optimizer/__init__.py
new file mode 100644
index 00000000..3d07bcdd
--- /dev/null
+++ b/torchopt/_src/optimizer/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from torchopt._src.optimizer import meta
+from torchopt._src.optimizer.adam import Adam
+from torchopt._src.optimizer.base import Optimizer
+from torchopt._src.optimizer.rmsprop import RMSProp
+from torchopt._src.optimizer.sgd import SGD
diff --git a/torchopt/_src/optimizer/adam.py b/torchopt/_src/optimizer/adam.py
new file mode 100644
index 00000000..1b0ce395
--- /dev/null
+++ b/torchopt/_src/optimizer/adam.py
@@ -0,0 +1,55 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from torchopt._src.alias import adam
+from torchopt._src.optimizer.base import Optimizer
+from torchopt._src.typing import ScalarOrSchedule
+
+
+class Adam(Optimizer):
+ """A canonical Stochastic Gradient Descent optimizer."""
+
+ def __init__(
+ self,
+ params,
+ lr: ScalarOrSchedule,
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ use_accelerated_op: bool = False
+ ):
+ """The `init` function.
+
+ Args:
+ params (iterable):
+ An iterable of `torch.Tensor`s. Specifies what Tensors should be
+ optimized.
+ args:
+ Other arguments see `alias.sgd`.
+ """
+
+ super().__init__(
+ params,
+ adam(
+ lr=lr,
+ b1=b1,
+ b2=b2,
+ eps=eps,
+ eps_root=eps_root,
+ moment_requires_grad=False,
+ use_accelerated_op=use_accelerated_op
+ )
+ )
diff --git a/torchopt/_src/optimizer/base.py b/torchopt/_src/optimizer/base.py
new file mode 100644
index 00000000..82f5284b
--- /dev/null
+++ b/torchopt/_src/optimizer/base.py
@@ -0,0 +1,127 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable
+
+import jax
+import torch
+
+from torchopt._src.base import GradientTransformation
+from torchopt._src.update import apply_updates
+
+
+class Optimizer(object):
+ """A high-level base class that has the similar with `torch.optim.Optimizer`."""
+
+ def __init__(self, params: Iterable, impl: GradientTransformation):
+ """The `init` function.
+
+ Args:
+ params (iterable):
+ An iterable of `torch.Tensor`s. Specifies what Tensors should be
+ optimized.
+ impl (GradientTransformation):
+ A low level optimizer function, it could be a optimizer function
+ provided by `alias.py` or a customized `chain` provided by
+ `combine.py`.
+ Note that use `MetaOptimizer(sgd())` or `MetaOptimizer(chain(sgd()))`
+ is equivalent to `SGD`.
+ """
+
+ if not isinstance(params, list):
+ params = list(params)
+ self.impl = impl
+ self.param_groups = [] # type: ignore
+ self.param_tree_groups = [] # type: ignore
+ self.state_groups = [] # type: ignore
+ self.add_param_group(params)
+
+ def zero_grad(self, set_to_none: bool = False):
+ """Sets the gradients of all optimized `torch.Tensor`s to zero.
+
+ The behavior is similar to `torch.optim.Optimizer.zero_grad`.
+
+ Args:
+ set_to_none (bool):
+ Instead of setting to zero, set the grads to None.
+ """
+
+ for group in self.param_groups:
+ if set_to_none:
+
+ def f(p):
+ p.grad = None
+ return None
+
+ else:
+
+ def f(p):
+ if p.grad is None:
+ return None
+ if p.grad.grad_fn is not None:
+ p.grad.detach_()
+ else:
+ p.grad.requires_grad_(False)
+ p.grad.zero_()
+ return None
+
+ jax.tree_map(f, group)
+
+ def state_dict(self):
+ """Returns the state of the optimizer."""
+
+ return self.state_groups
+
+ def load_state_dict(self, state_dict):
+ """Loads the optimizer state.
+
+ Args:
+ state_dict (dict):
+ Optimizer state. Should be an object returned from a call to :meth:`state_dict`.
+ """
+
+ self.state_groups = state_dict
+
+ def step(self, closure=None):
+ """Performs a single optimization step (parameter update).
+
+ The behavior is similar to `torch.optim.Optimizer.step`.
+
+ Args:
+ closure (callable, optional):
+ A closure that reevaluates the model and returns the loss.
+ """
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ def f(p):
+ return p.grad
+
+ for param, state in zip(self.param_groups, self.state_groups):
+ grad = jax.tree_map(f, param)
+ updates, _ = self.impl.update(grad, state)
+ apply_updates(param, updates)
+
+ return loss
+
+ def add_param_group(self, params):
+ params, tree = jax.tree_flatten(params)
+ params = tuple(params)
+ self.param_groups.append(params)
+ self.param_tree_groups.append(tree)
+ self.state_groups.append(self.impl.init(params))
diff --git a/torchopt/_src/optimizer/meta/__init__.py b/torchopt/_src/optimizer/meta/__init__.py
new file mode 100644
index 00000000..86fcb3b3
--- /dev/null
+++ b/torchopt/_src/optimizer/meta/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from torchopt._src.optimizer.meta.adam import MetaAdam
+from torchopt._src.optimizer.meta.base import MetaOptimizer
+from torchopt._src.optimizer.meta.rmsprop import MetaRMSProp
+from torchopt._src.optimizer.meta.sgd import MetaSGD
diff --git a/torchopt/_src/optimizer/meta/adam.py b/torchopt/_src/optimizer/meta/adam.py
new file mode 100644
index 00000000..d699b3b5
--- /dev/null
+++ b/torchopt/_src/optimizer/meta/adam.py
@@ -0,0 +1,56 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from torchopt._src.alias import adam
+from torchopt._src.optimizer.meta.base import MetaOptimizer
+from torchopt._src.typing import ScalarOrSchedule
+
+
+class MetaAdam(MetaOptimizer):
+ """The classic Adam optimizer."""
+
+ def __init__(
+ self,
+ net,
+ lr: ScalarOrSchedule,
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ moment_requires_grad: bool = True,
+ use_accelerated_op: bool = False
+ ):
+ """The `init` function.
+
+ Args:
+ net (nn.Module):
+ A network whose parameters should be optimized.
+ args:
+ Other arguments see `alias.adam`, here we set `moment_requires_grad=True`
+ to make tensors like momentum be differentiable.
+ """
+
+ super().__init__(
+ net,
+ adam(
+ lr=lr,
+ b1=b1,
+ b2=b2,
+ eps=eps,
+ eps_root=eps_root,
+ moment_requires_grad=moment_requires_grad,
+ use_accelerated_op=use_accelerated_op
+ )
+ )
diff --git a/torchopt/_src/optimizer/meta/base.py b/torchopt/_src/optimizer/meta/base.py
new file mode 100644
index 00000000..486ff15d
--- /dev/null
+++ b/torchopt/_src/optimizer/meta/base.py
@@ -0,0 +1,94 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import jax
+import torch
+import torch.nn as nn
+
+from torchopt._src.base import GradientTransformation
+from torchopt._src.update import apply_updates
+
+
+class MetaOptimizer(object):
+ """A high-level optimizer base class for meta learning."""
+
+ def __init__(self, net: nn.Module, impl: GradientTransformation):
+ """
+ Args:
+ net (nn.Module):
+ A network whose parameters should be optimized.
+ impl (GradientTransformation):
+ A low level optimizer function, it could be a optimizer function
+ provided by `alias.py` or a customerized `chain` provided by
+ `combine.py`.
+ Note that use `MetaOptimizer(sgd(moment_requires_grad=True))`
+ or `MetaOptimizer(chain(sgd(moment_requires_grad=True))) is
+ equivalent to `MetaSGD`.
+ """
+
+ self.impl = impl
+ self.param_containers_groups = [] # type: ignore
+ self.state_groups = [] # type: ignore
+
+ self.add_param_group(net)
+
+ def step(self, loss: torch.Tensor):
+ """Compute the gradients of the loss to the network parameters and update network parameters.
+
+ Graph of the derivative will be constructed, allowing to compute higher order derivative products.
+ We use the differentiable optimizer (pass argument inplace=False) to scale the gradients and update
+ the network parameters without modifying tensors in-place.
+
+ Args:
+ loss (torch.Tensor):
+ The loss that is used to compute the gradients to the network parameters.
+ """
+
+ # step parameter only
+ for idx, (state, param_containers) in enumerate(
+ zip(self.state_groups, self.param_containers_groups)
+ ):
+ flatten_params, containers_tree = jax.tree_util.tree_flatten(param_containers)
+ flatten_params = tuple(flatten_params)
+ grad = torch.autograd.grad(loss, flatten_params, create_graph=True, allow_unused=True)
+ updates, state = self.impl.update(grad, state, False)
+ self.state_groups[idx] = state
+ new_params = apply_updates(flatten_params, updates, inplace=False)
+ unflatten_new_params = containers_tree.unflatten(new_params)
+ for container, unflatten_param in zip(param_containers, unflatten_new_params):
+ container.update(unflatten_param)
+
+ def add_param_group(self, net):
+ from torchopt._src.utils import _extract_container
+
+ net_container = _extract_container(net, with_buffer=False)
+ flatten_param, _ = jax.tree_util.tree_flatten(net_container)
+ flatten_param = tuple(flatten_param)
+ optim_state = self.impl.init(flatten_param)
+ self.state_groups.append(optim_state)
+ self.param_containers_groups.append(net_container)
+
+ def state_dict(self):
+ """Extract the references of the optimizer states.
+
+ Note that the states are references, so any in-place operations will
+ change the states inside `MetaOptimizer` at the same time.
+ """
+
+ out_groups = tuple(group for group in self.state_groups)
+ return out_groups
+
+ def load_state_dict(self, state_dict):
+ self.state_groups = list(group for group in state_dict)
diff --git a/torchopt/_src/optimizer/meta/rmsprop.py b/torchopt/_src/optimizer/meta/rmsprop.py
new file mode 100644
index 00000000..eb742b04
--- /dev/null
+++ b/torchopt/_src/optimizer/meta/rmsprop.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Union
+
+from torchopt._src.alias import rmsprop
+from torchopt._src.optimizer.meta.base import MetaOptimizer
+from torchopt._src.typing import ScalarOrSchedule
+
+
+class MetaRMSProp(MetaOptimizer):
+ """The classic RMSProp optimizer."""
+
+ def __init__(
+ self,
+ net,
+ lr: ScalarOrSchedule,
+ decay: float = 0.9,
+ eps: float = 1e-8,
+ initial_scale: float = 0.,
+ centered: bool = False,
+ momentum: Union[float, None] = None,
+ nesterov: bool = False
+ ):
+ """The `init` function.
+
+ Args:
+ net (nn.Module):
+ A network whose parameters should be optimized.
+ args:
+ Other arguments see `alias.adam`, here we set `moment_requires_grad=True`
+ to make tensors like momentum be differentiable.
+ """
+
+ super().__init__(
+ net,
+ rmsprop(
+ lr=lr,
+ decay=decay,
+ eps=eps,
+ initial_scale=initial_scale,
+ centered=centered,
+ momentum=momentum,
+ nesterov=nesterov
+ )
+ )
diff --git a/torchopt/_src/optimizer/meta/sgd.py b/torchopt/_src/optimizer/meta/sgd.py
new file mode 100644
index 00000000..bbd57b46
--- /dev/null
+++ b/torchopt/_src/optimizer/meta/sgd.py
@@ -0,0 +1,54 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Union
+
+import torch.nn as nn
+
+from torchopt._src.alias import sgd
+from torchopt._src.optimizer.meta.base import MetaOptimizer
+from torchopt._src.typing import ScalarOrSchedule
+
+
+class MetaSGD(MetaOptimizer):
+ """A canonical Stochastic Gradient Descent optimizer."""
+
+ def __init__(
+ self,
+ net: nn.Module,
+ lr: ScalarOrSchedule,
+ momentum: Union[float, None] = None,
+ nesterov: bool = False,
+ moment_requires_grad: bool = True
+ ):
+ """The `init` function.
+
+ Args:
+ net (nn.Module):
+ A network whose parameters should be optimized.
+ args:
+ Other arguments see `alias.sgd`, here we set `moment_requires_grad=True`
+ to make tensors like momentum be differentiable.
+ """
+
+ super().__init__(
+ net,
+ sgd(
+ lr=lr,
+ momentum=momentum,
+ nesterov=nesterov,
+ moment_requires_grad=moment_requires_grad
+ )
+ )
diff --git a/torchopt/_src/optimizer/rmsprop.py b/torchopt/_src/optimizer/rmsprop.py
new file mode 100644
index 00000000..d1aaf278
--- /dev/null
+++ b/torchopt/_src/optimizer/rmsprop.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Union
+
+from torchopt._src.alias import rmsprop
+from torchopt._src.optimizer.base import Optimizer
+from torchopt._src.typing import ScalarOrSchedule
+
+
+class RMSProp(Optimizer):
+ """An RMSProp optimizer."""
+
+ def __init__(
+ self,
+ params,
+ lr: ScalarOrSchedule,
+ decay: float = 0.9,
+ eps: float = 1e-8,
+ initial_scale: float = 0.,
+ centered: bool = False,
+ momentum: Union[float, None] = None,
+ nesterov: bool = False
+ ):
+ """The `init` function.
+
+ Args:
+ params (iterable):
+ An iterable of `torch.Tensor`s. Specifies what Tensors should be
+ optimized.
+ args:
+ Other arguments see `alias.sgd`.
+ """
+
+ super().__init__(
+ params,
+ rmsprop(
+ lr=lr,
+ decay=decay,
+ eps=eps,
+ initial_scale=initial_scale,
+ centered=centered,
+ momentum=momentum,
+ nesterov=nesterov
+ )
+ )
diff --git a/torchopt/_src/optimizer/sgd.py b/torchopt/_src/optimizer/sgd.py
new file mode 100644
index 00000000..9e3e1c98
--- /dev/null
+++ b/torchopt/_src/optimizer/sgd.py
@@ -0,0 +1,45 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Union
+
+from torchopt._src.alias import sgd
+from torchopt._src.optimizer.base import Optimizer
+from torchopt._src.typing import ScalarOrSchedule
+
+
+class SGD(Optimizer):
+ """The classic SGD optimizer."""
+
+ def __init__(
+ self,
+ params,
+ lr: ScalarOrSchedule,
+ momentum: Union[float, None] = None,
+ nesterov: bool = False
+ ):
+ """The `init` function.
+
+ Args:
+ params (iterable):
+ An iterable of `torch.Tensor`s. Specifies what Tensors should be
+ optimized.
+ args:
+ Other arguments see `alias.adam`.
+ """
+
+ super().__init__(
+ params, sgd(lr=lr, momentum=momentum, nesterov=nesterov, moment_requires_grad=False)
+ )
diff --git a/torchopt/_src/schedule.py b/torchopt/_src/schedule.py
new file mode 100644
index 00000000..864afb69
--- /dev/null
+++ b/torchopt/_src/schedule.py
@@ -0,0 +1,111 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py
+# ==============================================================================
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import jax
+import numpy as np
+from absl import logging
+
+from torchopt._src import base, typing
+
+
+def polynomial_schedule(
+ init_value: typing.Scalar,
+ end_value: typing.Scalar,
+ power: typing.Scalar,
+ transition_steps: int,
+ transition_begin: int = 0
+) -> base.Schedule:
+ """Constructs a schedule with polynomial transition from init to end value.
+
+ Args:
+ init_value:
+ Initial value for the scalar to be annealed.
+ end_value:
+ End value of the scalar to be annealed.
+ power:
+ The power of the polynomial used to transition from init to end.
+ transition_steps:
+ Number of steps over which annealing takes place, the scalar starts
+ changing at `transition_begin` steps and completes the transition
+ by `transition_begin + transition_steps` steps.
+ If `transition_steps <= 0`, then the entire annealing process is
+ disabled and the value is held fixed at `init_value`.
+ transition_begin:
+ Must be positive. After how many steps to start annealing (before
+ this many steps the scalar value is held fixed at `init_value`).
+
+ Returns:
+ schedule:
+ A function that maps step counts to values.
+ """
+
+ if transition_steps <= 0:
+ logging.info(
+ 'A polynomial schedule was set with a non-positive `transition_steps` '
+ 'value; this results in a constant schedule with value `init_value`.'
+ )
+ return lambda count: init_value
+
+ if transition_begin < 0:
+ logging.info(
+ 'An exponential schedule was set with a negative `transition_begin` '
+ 'value; this will result in `transition_begin` falling back to `0`.'
+ )
+ transition_begin = 0
+
+ def schedule(count):
+
+ def impl(count):
+ count = np.clip(count - transition_begin, 0, transition_steps)
+ frac = 1 - count / transition_steps
+ return (init_value - end_value) * (frac**power) + end_value
+
+ return jax.tree_map(impl, count)
+
+ return schedule
+
+
+# Alias polynomial schedule to linear schedule for convenience.
+def linear_schedule(
+ init_value: typing.Scalar,
+ end_value: typing.Scalar,
+ transition_steps: int,
+ transition_begin: int = 0
+) -> base.Schedule:
+
+ return polynomial_schedule(
+ init_value=init_value,
+ end_value=end_value,
+ power=1,
+ transition_steps=transition_steps,
+ transition_begin=transition_begin
+ )
diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py
new file mode 100644
index 00000000..7aef0c84
--- /dev/null
+++ b/torchopt/_src/transform.py
@@ -0,0 +1,472 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py
+# ==============================================================================
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import List, NamedTuple, Tuple, Union
+
+import jax
+import torch
+
+from torchopt._src import base
+from torchopt._src.typing import ScalarOrSchedule, Schedule
+
+ScaleState = base.EmptyState
+
+
+def inc_count(updates, count: Tuple[int]) -> Tuple[int]:
+
+ def f(c, g):
+ return c + 1 if g is not None else c
+
+ return jax.tree_map(f, count, updates)
+
+
+def scale(step_size: float) -> base.GradientTransformation:
+ """Scale updates by some fixed scalar `step_size`.
+
+ Args:
+ step_size:
+ A scalar corresponding to a fixed scaling factor for updates.
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(params):
+ del params
+ return ScaleState()
+
+ def update_fn(updates, state, inplace=True):
+ if inplace:
+
+ def f(g):
+ return g.mul_(step_size) if g is not None else None
+ else:
+
+ def f(g):
+ return g.mul(step_size) if g is not None else None
+
+ updates = jax.tree_map(f, updates)
+ return updates, state
+
+ return base.GradientTransformation(init_fn, update_fn)
+
+
+class ScaleByScheduleState(NamedTuple):
+ """Maintains count for scale scheduling."""
+
+ count: Tuple[int, ...] # type: ignore
+
+
+def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation:
+ """Scale updates using a custom schedule for the `step_size`.
+
+ Args:
+ step_size_fn:
+ A function that takes an update count as input and proposes the
+ step_size to multiply the updates by.
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(params):
+ return ScaleByScheduleState(count=tuple(0 for _ in range(len(params))))
+
+ def update_fn(updates, state, inplace=True):
+ step_size = step_size_fn(state.count)
+ if inplace:
+ updates = jax.tree_map(lambda g, step_size: g.mul_(step_size), updates, step_size)
+ else:
+ updates = jax.tree_map(lambda g, step_size: g.mul(step_size), updates, step_size)
+ return updates, ScaleByScheduleState(count=inc_count(updates, state.count))
+
+ return base.GradientTransformation(init_fn, update_fn)
+
+
+def _update_moment(updates, moments, decay, order, inplace=True):
+ """Compute the exponential moving average of the `order`-th moment."""
+
+ if inplace:
+
+ def f(g, t):
+ return t.mul_(decay).add_(g**order, alpha=1 - decay) if g is not None else t
+ else:
+
+ def f(g, t):
+ return t.mul(decay).add(g**order, alpha=1 - decay) if g is not None else t
+
+ return jax.tree_map(f, updates, moments)
+
+
+def _update_moment_per_elem_norm(updates, moments, decay, order, inplace=True):
+ """Compute the EMA of the `order`-th moment of the element-wise norm."""
+
+ if inplace:
+
+ def f(g, t):
+ return t.mul_(decay).add_(g**order, alpha=1 - decay) if g is not None else t
+ else:
+
+ def f(g, t):
+ return t.mul(decay).add(g**order, alpha=1 - decay) if g is not None else t
+
+ return jax.tree_map(f, updates, moments)
+
+
+class ScaleByAdamState(NamedTuple):
+ """State for the Adam algorithm."""
+
+ count: Tuple[int, ...] # type: ignore
+ mu: base.Updates
+ nu: base.Updates
+
+
+def _bias_correction(moment, decay, count, inplace=True):
+ """Perform bias correction. This becomes a no-op as count goes to infinity."""
+
+ if inplace:
+
+ def f(t, c):
+ return t.div_(1 - decay**c)
+ else:
+
+ def f(t, c):
+ return t.div(1 - decay**c)
+
+ return jax.tree_map(f, moment, count)
+
+
+def scale_by_adam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ moment_requires_grad: bool = False,
+) -> base.GradientTransformation:
+ """Rescale updates according to the Adam algorithm.
+
+ References:
+ [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
+
+ Args:
+ b1:
+ Decay rate for the exponentially weighted average of grads.
+ b2:
+ Decay rate for the exponentially weighted average of squared grads.
+ eps:
+ Term added to the denominator to improve numerical stability.
+ eps_root:
+ Term added to the denominator inside the square-root to improve
+ numerical stability when backpropagating gradients through the rescaling.
+ moment_requires_grad:
+ If true, states will be created with flag `requires_grad = True`.
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(params):
+ mu = jax.tree_map( # First moment
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params)
+ nu = jax.tree_map( # Second moment
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params)
+ return ScaleByAdamState(count=tuple(0 for _ in range(len(mu))), mu=tuple(mu), nu=tuple(nu))
+
+ def update_fn(updates, state, inplace=True):
+ mu = _update_moment(updates, state.mu, b1, 1, inplace)
+ nu = _update_moment_per_elem_norm(updates, state.nu, b2, 2, inplace)
+ count_inc = inc_count(updates, state.count)
+ mu_hat = _bias_correction(mu, b1, count_inc, False)
+ nu_hat = _bias_correction(nu, b2, count_inc, False)
+ if inplace:
+
+ def f(g, m, v):
+ return m.div_(torch.sqrt_(v.add_(eps_root)).add_(eps)) if g is not None else None
+ else:
+
+ def f(g, m, v):
+ return m.div(torch.sqrt(v.add(eps_root)).add(eps)) if g is not None else None
+
+ updates = jax.tree_map(f, updates, mu_hat, nu_hat)
+ return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
+
+ return base.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_accelerated_adam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ moment_requires_grad: bool = False,
+) -> base.GradientTransformation:
+ """Rescale updates according to the Adam algorithm.
+
+ This function is accelerated by using some fused accelerated operators.
+
+ References:
+ [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
+
+ Args:
+ b1:
+ Decay rate for the exponentially weighted average of grads.
+ b2:
+ Decay rate for the exponentially weighted average of squared grads.
+ eps:
+ Term added to the denominator to improve numerical stability.
+ eps_root:
+ Term added to the denominator inside the square-root to improve
+ numerical stability when backpropagating gradients through the rescaling.
+ moment_requires_grad:
+ If true, states will be created with flag `requires_grad = True`.
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ from .accelerated_op import AdamOp
+
+ def init_fn(params):
+ mu = jax.tree_map( # First moment
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params)
+ nu = jax.tree_map( # Second moment
+ lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+ params)
+ return ScaleByAdamState(count=tuple(0 for _ in range(len(params))), mu=mu, nu=nu)
+
+ def update_fn(updates, state, inplace=True):
+ count_inc = inc_count(updates, state.count)
+ op = AdamOp(b1, b2, eps, eps_root, inplace)
+ out = jax.tree_map(op, state.mu, state.nu, updates, count_inc)
+ new_mus, new_nus, new_updates = [], [], []
+ for new_mu, new_nu, new_update in out:
+ new_mus.append(new_mu)
+ new_nus.append(new_nu)
+ new_updates.append(new_update)
+ return tuple(new_updates), ScaleByAdamState(
+ count=count_inc, mu=tuple(new_mus), nu=tuple(new_nus)
+ )
+
+ return base.GradientTransformation(init_fn, update_fn)
+
+
+class TraceState(NamedTuple):
+ """Holds an aggregation of past updates."""
+
+ trace: base.Params
+
+
+def trace(
+ decay: float,
+ nesterov: bool = False,
+ moment_requires_grad: bool = False,
+) -> base.GradientTransformation:
+ """Compute a trace of past updates.
+
+ Note: `trace` and `ema` have very similar but distinct updates;
+ `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`.
+ Both are frequently found in the optimisation literature.
+
+ Args:
+ decay:
+ The decay rate for the trace of past updates.
+ nesterov:
+ Whether to use Nesterov momentum.
+ moment_requires_grad:
+ If true, states will be created with flag `requires_grad = True`.
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(params):
+ if decay == 0.:
+ return TraceState(trace=())
+ else:
+ return TraceState(
+ trace=jax.
+ tree_map(lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params)
+ )
+
+ def update_fn(updates, state, inplace=True):
+ if nesterov:
+ if inplace:
+
+ def f1(g, t):
+ return t.copy_(g.add(t, alpha=decay))
+
+ def f2(g, t):
+ return g.add_(t, alpha=decay)
+
+ new_trace = jax.tree_map(f1, updates, state.trace)
+ updates = jax.tree_map(f2, updates, new_trace)
+ else:
+
+ def f(g, t):
+ return g.add(t, alpha=decay)
+
+ new_trace = jax.tree_map(f, updates, state.trace)
+ updates = jax.tree_map(f, updates, new_trace)
+ else:
+ if inplace:
+
+ def f(g, t):
+ return g.add_(t, alpha=decay)
+
+ updates = jax.tree_map(f, updates, state.trace)
+ state.trace.copy_(updates)
+ new_trace = state.trace
+ else:
+
+ def f(g, t):
+ return g.add(t, alpha=decay)
+
+ updates = jax.tree_map(f, updates, state.trace)
+ new_trace = updates
+
+ return updates, TraceState(trace=new_trace)
+
+ return base.GradientTransformation(init_fn, update_fn)
+
+
+class ScaleByRmsState(NamedTuple):
+ """State for exponential root mean-squared (RMS)-normalized updates."""
+
+ nu: base.Updates
+
+
+def scale_by_rms(
+ decay: float = 0.9,
+ eps: float = 1e-8,
+ initial_scale: float = 0.
+) -> base.GradientTransformation:
+ """Rescale updates by the root of the exp. moving avg of the square.
+
+ References:
+ [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+
+ Args:
+ decay:
+ Decay rate for the exponentially weighted average of squared grads.
+ eps:
+ Term added to the denominator to improve numerical stability.
+ initial_scale:
+ Initial value for second moment
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(params):
+ nu = jax.tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment
+ return ScaleByRmsState(nu=nu)
+
+ def update_fn(updates, state, inplace=True):
+ nu = _update_moment_per_elem_norm(updates, state.nu, decay, 2, inplace)
+ if inplace:
+
+ def f(g, n):
+ return g.mul_(torch.rsqrt(n.add(eps)))
+ else:
+
+ def f(g, n):
+ return g.mul(torch.rsqrt(n.add(eps)))
+
+ # """The followings are pytorch style"""
+ # if inplace:
+ # def f(g, n): return g.div_(torch.sqrt_(n).add_(eps))
+ # else:
+ # def f(g, n): return g.div(torch.sqrt(n).add(eps))
+ updates = jax.tree_map(f, updates, nu)
+ return updates, ScaleByRmsState(nu=nu)
+
+ return base.GradientTransformation(init_fn, update_fn)
+
+
+class ScaleByRStdDevState(NamedTuple):
+ """State for centered exponential moving average of squares of updates."""
+
+ mu: base.Updates
+ nu: base.Updates
+
+
+def scale_by_stddev(
+ decay: float = 0.9,
+ eps: float = 1e-8,
+ initial_scale: float = 0.
+) -> base.GradientTransformation:
+ """Rescale updates by the root of the centered exp. moving average of squares.
+
+ References:
+ [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+
+ Args:
+ decay:
+ Decay rate for the exponentially weighted average of squared grads.
+ eps:
+ Term added to the denominator to improve numerical stability.
+ initial_scale:
+ Initial value for second moment
+
+ Returns:
+ An (init_fn, update_fn) tuple.
+ """
+
+ def init_fn(params):
+ mu = jax.tree_map(torch.zeros_like, params) # First moment
+ nu = jax.tree_map(lambda n: torch.full_like(n, initial_scale), params) # second moment
+ return ScaleByRStdDevState(mu=mu, nu=nu)
+
+ def update_fn(updates, state, inplace=True):
+ mu = _update_moment(updates, state.mu, decay, 1, inplace)
+ nu = _update_moment_per_elem_norm(updates, state.nu, decay, 2, inplace)
+ if inplace:
+
+ def f(g, m, n):
+ return g.mul_(torch.rsqrt(n.sub(m**2).add(eps)))
+ else:
+
+ def f(g, m, n):
+ return g.mul(torch.rsqrt(n.sub(m**2).add(eps)))
+
+ # """The followings are pytorch style"""
+ # if inplace:
+ # def f(g, m, n): return g.div_(torch.sqrt_(n.sub_(m ** 2)).add(eps))
+ # else:
+ # def f(g, m, n): return g.div(torch.sqrt(n.sub(m ** 2)).add(eps))
+ updates = jax.tree_map(f, updates, mu, nu)
+ return updates, ScaleByRStdDevState(mu=mu, nu=nu)
+
+ return base.GradientTransformation(init_fn, update_fn)
diff --git a/TorchOpt/_src/pytypes.py b/torchopt/_src/typing.py
similarity index 100%
rename from TorchOpt/_src/pytypes.py
rename to torchopt/_src/typing.py
diff --git a/TorchOpt/_src/update.py b/torchopt/_src/update.py
similarity index 54%
rename from TorchOpt/_src/update.py
rename to torchopt/_src/update.py
index a77adf7e..2d17adb7 100644
--- a/TorchOpt/_src/update.py
+++ b/torchopt/_src/update.py
@@ -32,41 +32,42 @@
import jax
-from TorchOpt._src import base
+from torchopt._src import base
-def apply_updates(
- params: base.Params,
- updates: base.Updates,
- inplace: bool = True
-) -> base.Params:
- """Applies an update to the corresponding parameters.
+def apply_updates(params: base.Params, updates: base.Updates, inplace: bool = True) -> base.Params:
+ """Applies an update to the corresponding parameters.
- This is a utility functions that applies an update to a set of parameters, and
- then returns the updated parameters to the caller. As an example, the update
- may be a gradient transformed by a sequence of`GradientTransformations`. This
- function is exposed for convenience, but it just adds updates and parameters;
- you may also apply updates to parameters manually, using `tree_map`
- (e.g. if you want to manipulate updates in custom ways before applying them).
+ This is a utility functions that applies an update to a set of parameters,
+ and then returns the updated parameters to the caller. As an example, the
+ update may be a gradient transformed by a sequence of`GradientTransformations`.
+ This function is exposed for convenience, but it just adds updates and parameters;
+ you may also apply updates to parameters manually, using `tree_map` (e.g. if
+ you want to manipulate updates in custom ways before applying them).
- Args:
- params: a tree of parameters.
- updates: a tree of updates, the tree structure and the shape of the leaf
- nodes must match that of `params`.
- inplace: if True, will update params in a inplace manner.
+ Args:
+ params:
+ A tree of parameters.
+ updates:
+ A tree of updates, the tree structure and the shape of the leaf
+ nodes must match that of `params`.
+ inplace:
+ If true, will update params in a inplace manner.
- Returns:
- Updated parameters, with same structure, shape and type as `params`.
- """
- if inplace:
+ Returns:
+ Updated parameters, with same structure, shape and type as `params`.
+ """
- def f(p, u):
- if u is not None:
- p.data.add_(u)
- return p
- else:
+ if inplace:
- def f(p, u):
- return p.add(u) if u is not None else p
+ def f(p, u):
+ if u is not None:
+ p.data.add_(u)
+ return p
- return jax.tree_map(f, params, updates)
+ else:
+
+ def f(p, u):
+ return p.add(u) if u is not None else p
+
+ return jax.tree_map(f, params, updates)
diff --git a/torchopt/_src/utils.py b/torchopt/_src/utils.py
new file mode 100644
index 00000000..79921916
--- /dev/null
+++ b/torchopt/_src/utils.py
@@ -0,0 +1,197 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Dict, List, NamedTuple, Union
+
+import jax
+import torch
+import torch.nn as nn
+
+from torchopt._src.optimizer.meta import MetaOptimizer
+
+
+class _ModuleState(NamedTuple):
+ params: List[Dict]
+ visual_contents: Union[None, Dict] = None
+
+
+# mypy: ignore-errors
+def stop_gradient(target):
+ """Stop the gradient for the input object.
+
+ Since a tensor use `grad_fn` to connect itself with the previous computation
+ graph, the back-propagated gradient will flow over the tensor and continue
+ flow to the tensors that is connected by `grad_fn`. Some algorithms requires
+ manually detaching tensors from the computation graph.
+
+ Note that the stop_gradient operation is in-place.
+
+ Args:
+ target:
+ The target that to be detached from the computation graph, it could
+ be a `nn.Module`, `torchopt.MetaOptimizer`, state of the
+ `torchopt.MetaOptimizer`, or just a plain list of tensors.
+ inplace:
+ If true, the target will be detached in-place. if false, this function
+ will return a detached copy of the target. The in-place operation is
+ fast and memory efficient but may raise back-propagation error.
+ """
+
+ def f(obj):
+ if isinstance(obj, torch.Tensor):
+ requires_grad = obj.requires_grad
+ obj.detach_().requires_grad_(requires_grad)
+ return None
+
+ if isinstance(target, _ModuleState):
+ true_target = target.params
+ elif isinstance(target, nn.Module):
+ true_target = tuple(target.parameters())
+ elif isinstance(target, MetaOptimizer):
+ true_target, _ = jax.tree_flatten(target.state_dict())
+ else:
+ true_target = target
+
+ jax.tree_map(f, true_target)
+
+
+def extract_state_dict(mod, copy=False, *, with_buffer=True, enable_visual=False, visual_prefix=''):
+ """Extract target state.
+
+ Since a tensor use `grad_fn` to connect itself with the previous computation
+ graph, the back-propagated gradient will flow over the tensor and continue
+ flow to the tensors that is connected by `grad_fn`. Some algorithms requires
+ manually detaching tensors from the computation graph.
+
+ Note that the extracted state is a reference, which means any in-place operator
+ will affect the target that the state is extracted from.
+
+ Args:
+ mod:
+ It could be a `nn.Module` or `torchopt.MetaOptimizer`.
+ with_buffer:
+ Extract buffer together with parameters, this argument is only used
+ if the input target is `nn.Module`.
+ enable_visual:
+ Add additional annotations, which could be used in computation graph
+ visualization. Currently, this flag only has effect on `nn.Module` but
+ we will support `torchopt.MetaOptimizer` later.
+ visual_prefix:
+ Prefix for the visualization annotations.
+
+ Returns:
+ State extracted of the input object.
+ """
+
+ if isinstance(mod, nn.Module):
+ if enable_visual:
+ visual_contents = {}
+
+ for k, v in mod.named_parameters():
+ if v.grad_fn is not None:
+ visual_contents.update({v.grad_fn: (visual_prefix + k, v)})
+ else:
+ visual_contents.update({v: visual_prefix + k})
+ else:
+ visual_contents = None
+
+ params = []
+
+ def get_v(v):
+ if copy:
+ requires_grad = v.requires_grad
+ return v.clone().detach_().requires_grad_(requires_grad)
+ else:
+ return v
+
+ def _update(term):
+ if len(term) != 0:
+ params.append({k: get_v(v) for k, v in term.items()})
+
+ _update(mod._parameters)
+ if with_buffer:
+ _update(mod._buffers)
+ for module in mod.modules():
+ if module is mod:
+ continue
+ _update(module._parameters)
+ if with_buffer:
+ _update(module._buffers)
+ return _ModuleState(params=tuple(params), visual_contents=visual_contents)
+ elif isinstance(mod, MetaOptimizer):
+ state = mod.state_dict()
+ if copy:
+ flatten_state, state_tree = jax.tree_flatten(state)
+
+ def get_v(v):
+ if not isinstance(v, torch.Tensor):
+ return v
+ requires_grad = v.requires_grad
+ return v.clone().detach_().requires_grad_(requires_grad)
+
+ flatten_state = jax.tree_map(get_v, flatten_state)
+ return state_tree.unflatten(flatten_state)
+ else:
+ return state
+
+ else:
+ raise RuntimeError(f"Unexpected class of {mod}")
+
+
+def _extract_container(mod, with_buffer=True):
+ if isinstance(mod, nn.Module):
+ containers = []
+
+ def _update(term):
+ if len(term) != 0:
+ containers.append(term)
+
+ _update(mod._parameters)
+ if with_buffer:
+ _update(mod._buffers)
+ for module in mod.modules():
+ if module is mod:
+ continue
+ _update(module._parameters)
+ if with_buffer:
+ _update(module._buffers)
+ return tuple(containers)
+ else:
+ raise RuntimeError(f"Unexpected class of {mod}")
+
+
+def recover_state_dict(mod, state):
+ """Recover state.
+
+ This function is compatible for the `extract_state`.
+
+ Note that the recovering process is not in-place, so the tensors of the object
+ will not be modified.
+
+ Args:
+ mod:
+ Target that need to recover.
+ state:
+ The recovering state.
+ """
+
+ if isinstance(mod, nn.Module):
+ target_container = _extract_container(mod)
+ for target, source in zip(target_container, state.params):
+ target.update(source)
+ elif isinstance(mod, MetaOptimizer):
+ mod.load_state_dict(state)
+ else:
+ raise RuntimeError(f"Unexpected class of {mod}")
diff --git a/torchopt/_src/visual.py b/torchopt/_src/visual.py
new file mode 100644
index 00000000..696a1f77
--- /dev/null
+++ b/torchopt/_src/visual.py
@@ -0,0 +1,238 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# This file is modified from:
+# https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py
+# ==============================================================================
+
+import warnings
+from collections import namedtuple
+from distutils.version import LooseVersion
+from typing import Dict, Generator
+
+import torch
+from graphviz import Digraph
+
+Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))
+
+# Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*`
+SAVED_PREFIX = "_saved_"
+
+
+def get_fn_name(fn, show_attrs, max_attr_chars):
+ name = str(type(fn).__name__)
+ if not show_attrs:
+ return name
+ attrs = dict()
+ for attr in dir(fn):
+ if not attr.startswith(SAVED_PREFIX):
+ continue
+ val = getattr(fn, attr)
+ attr = attr[len(SAVED_PREFIX):]
+ if torch.is_tensor(val):
+ attrs[attr] = "[saved tensor]"
+ elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val):
+ attrs[attr] = "[saved tensors]"
+ else:
+ attrs[attr] = str(val)
+ if not attrs:
+ return name
+ max_attr_chars = max(max_attr_chars, 3)
+ col1width = max(len(k) for k in attrs.keys())
+ col2width = min(max(len(str(v)) for v in attrs.values()), max_attr_chars)
+ sep = "-" * max(col1width + col2width + 2, len(name))
+ attrstr = '%-' + str(col1width) + 's: %' + str(col2width) + 's'
+
+ def truncate(s):
+ return s[:col2width - 3] + "..." if len(s) > col2width else s
+
+ params = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items())
+ return name + '\n' + sep + '\n' + params
+
+
+# mypy: ignore-errors
+def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
+ """Produces Graphviz representation of PyTorch autograd graph.
+
+ If a node represents a backward function, it is gray. Otherwise, the node
+ represents a tensor and is either blue, orange, or green:
+ - Blue: reachable leaf tensors that requires grad (tensors whose `.grad`
+ fields will be populated during `.backward()`)
+ - Orange: saved tensors of custom autograd functions as well as those
+ saved by built-in backward nodes
+ - Green: tensor passed in as outputs
+ - Dark green: if any output is a view, we represent its base tensor with
+ a dark green node.
+
+ Args:
+ var:
+ Output tensor.
+ params: ([dict of (name, tensor) or state_dict])
+ Parameters to add names to node that requires grad.
+ show_attrs:
+ Whether to display non-tensor attributes of backward nodes
+ (Requires PyTorch version >= 1.9)
+ show_saved:
+ Whether to display saved tensor nodes that are not by custom
+ autograd functions. Saved tensor nodes for custom functions, if
+ present, are always displayed. (Requires PyTorch version >= 1.9)
+ max_attr_chars:
+ If show_attrs is `True`, sets max number of characters
+ to display for any given attribute.
+ """
+
+ if LooseVersion(torch.__version__) < LooseVersion("1.9") and \
+ (show_attrs or show_saved):
+ warnings.warn(
+ "make_dot: showing grad_fn attributes and saved variables"
+ " requires PyTorch version >= 1.9. (This does NOT apply to"
+ " saved tensors saved by custom autograd functions.)"
+ )
+
+ param_map = {}
+
+ if params is not None:
+ from torchopt._src.utils import _ModuleState
+
+ if isinstance(params, _ModuleState):
+ param_map.update(params.visual_contents)
+ elif isinstance(params, Dict):
+ param_map.update({v: k for k, v in params.items()})
+ elif isinstance(params, Generator):
+ param_map.update({v: k for k, v in params})
+ else:
+ for param in params:
+ if isinstance(param, _ModuleState):
+ param_map.update(param.visual_contents)
+ elif isinstance(param, Generator):
+ param_map.update({v: k for k, v in param})
+ else:
+ param_map.update({v: k for k, v in param.items()})
+
+ node_attr = dict(
+ style='filled',
+ shape='box',
+ align='left',
+ fontsize='10',
+ ranksep='0.1',
+ height='0.2',
+ fontname='monospace'
+ )
+ dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
+ seen = set()
+
+ def size_to_str(size):
+ return '(' + (', ').join(['%d' % v for v in size]) + ')'
+
+ def get_var_name(var, name=None):
+ if not name:
+ name = param_map[var] if var in param_map else ''
+ return '%s\n %s' % (name, size_to_str(var.size()))
+
+ def get_var_name_with_flag(var):
+ if var in param_map:
+ return '%s\n %s' % (param_map[var][0], size_to_str(param_map[var][1].size()))
+ else:
+ return None
+
+ def add_nodes(fn):
+ assert not torch.is_tensor(fn)
+ if fn in seen:
+ return
+ seen.add(fn)
+
+ if show_saved:
+ for attr in dir(fn):
+ if not attr.startswith(SAVED_PREFIX):
+ continue
+ val = getattr(fn, attr)
+ seen.add(val)
+ attr = attr[len(SAVED_PREFIX):]
+ if torch.is_tensor(val):
+ dot.edge(str(id(fn)), str(id(val)), dir="none")
+ dot.node(str(id(val)), get_var_name(val, attr), fillcolor='orange')
+ if isinstance(val, tuple):
+ for i, t in enumerate(val):
+ if torch.is_tensor(t):
+ name = attr + '[%s]' % str(i)
+ dot.edge(str(id(fn)), str(id(t)), dir="none")
+ dot.node(str(id(t)), get_var_name(t, name), fillcolor='orange')
+
+ if hasattr(fn, 'variable'):
+ # if grad_accumulator, add the node for `.variable`
+ var = fn.variable
+ seen.add(var)
+ dot.node(str(id(var)), get_var_name(var), fillcolor='lightblue')
+ dot.edge(str(id(var)), str(id(fn)))
+
+ fn_name = get_fn_name(fn, show_attrs, max_attr_chars)
+ fn_fillcolor = None
+ var_name = get_var_name_with_flag(fn)
+ if var_name is not None:
+ fn_name = '%s\n %s' % (fn_name, var_name)
+ fn_fillcolor = 'lightblue'
+
+ # add the node for this grad_fn
+ dot.node(str(id(fn)), fn_name, fillcolor=fn_fillcolor)
+
+ # recurse
+ if hasattr(fn, 'next_functions'):
+ for u in fn.next_functions:
+ if u[0] is not None:
+ dot.edge(str(id(u[0])), str(id(fn)))
+ add_nodes(u[0])
+
+ # note: this used to show .saved_tensors in pytorch0.2, but stopped
+ # working* as it was moved to ATen and Variable-Tensor merged
+ # also note that this still works for custom autograd functions
+ if hasattr(fn, 'saved_tensors'):
+ for t in fn.saved_tensors:
+ dot.edge(str(id(t)), str(id(fn)))
+ dot.node(str(id(t)), get_var_name(t), fillcolor='orange')
+
+ def add_base_tensor(var, color='darkolivegreen1'):
+ if var in seen:
+ return
+ seen.add(var)
+ dot.node(str(id(var)), get_var_name(var), fillcolor=color)
+ if (var.grad_fn):
+ add_nodes(var.grad_fn)
+ dot.edge(str(id(var.grad_fn)), str(id(var)))
+ if var._is_view():
+ add_base_tensor(var._base, color='darkolivegreen3')
+ dot.edge(str(id(var._base)), str(id(var)), style="dotted")
+
+ # handle multiple outputs
+ if isinstance(var, tuple):
+ for v in var:
+ add_base_tensor(v)
+ else:
+ add_base_tensor(var)
+
+ resize_graph(dot)
+
+ return dot
+
+
+def resize_graph(dot, size_per_element=0.15, min_size=12):
+ """Resize the graph according to how much content it contains.
+ Modify the graph in place.
+ """
+
+ # Get the approximate number of nodes and edges
+ num_rows = len(dot.body)
+ content_size = num_rows * size_per_element
+ size = max(min_size, content_size)
+ size_str = str(size) + "," + str(size)
+ dot.graph_attr.update(size=size_str)
diff --git a/tutorials/1_Functional_Optimizer.ipynb b/tutorials/1_Functional_Optimizer.ipynb
old mode 100755
new mode 100644
index 868bb00d..2dff7be4
--- a/tutorials/1_Functional_Optimizer.ipynb
+++ b/tutorials/1_Functional_Optimizer.ipynb
@@ -37,12 +37,12 @@
"import torch\n",
"import functorch\n",
"import torch.autograd\n",
- "from torch import nn\n",
+ "import torch.nn as nn\n",
"import optax\n",
"import jax\n",
"from jax import numpy as jnp\n",
"\n",
- "import TorchOpt\n",
+ "import torchopt\n",
"\n",
"\n",
"class Net(nn.Module):\n",
@@ -138,7 +138,7 @@
" func, params = functorch.make_functional(net)\n",
"\n",
" lr = 1.\n",
- " optimizer = TorchOpt.adam(lr)\n",
+ " optimizer = torchopt.adam(lr)\n",
"\n",
" opt_state = optimizer.init(params)\n",
"\n",
@@ -150,7 +150,7 @@
" grad = torch.autograd.grad(loss, params)\n",
" updates, opt_state = optimizer.update(grad, opt_state)\n",
" print(params)\n",
- " params = TorchOpt.apply_updates(params, updates)\n",
+ " params = torchopt.apply_updates(params, updates)\n",
" print(params)"
]
},
@@ -181,7 +181,7 @@
"- Full TorchOpt\n",
"\n",
"The Third example is to illustrate that TorchOpt can also directly replace torch.optim with exactly the same usage. Note the API \n",
- "difference happens between TorchOpt.adam() and TorchOpt.Adam(). "
+ "difference happens between torchopt.adam() and torchopt.Adam(). "
]
},
{
@@ -196,7 +196,7 @@
" net = Net(dim)\n",
"\n",
" lr = 1.\n",
- " optim = TorchOpt.Adam(net.parameters(), lr=lr)\n",
+ " optim = torchopt.Adam(net.parameters(), lr=lr)\n",
"\n",
" xs = 2 * torch.ones(batch_size, dim)\n",
" ys = torch.ones(batch_size)\n",
@@ -294,7 +294,7 @@
"## 2. Differentiable Optimization with functional optimizor\n",
"Coupled with functional optimizer, you can conduct differentiable optimization by setting the inplce flag as False in update and apply_updates function. (which might be helpful for meta-learning algorithm implementation with functional programing style). \n",
"\n",
- "Note that TorchOpt.SGD, TorchOpt.Adam do not support differentiable optimization. Refer to the Meta Optimizer notebook for pytorch-like differentiable optimizers."
+ "Note that torchopt.SGD, torchopt.Adam do not support differentiable optimization. Refer to the Meta Optimizer notebook for pytorch-like differentiable optimizers."
]
},
{
@@ -311,7 +311,7 @@
"\n",
" lr = 1.\n",
" # sgd example\n",
- " optimizer = TorchOpt.sgd(lr)\n",
+ " optimizer = torchopt.sgd(lr)\n",
" meta_param = torch.tensor(1., requires_grad=True)\n",
"\n",
" opt_state = optimizer.init(params)\n",
@@ -325,7 +325,7 @@
" loss = ((pred - ys) ** 2).sum()\n",
" grad = torch.autograd.grad(loss, params, create_graph=True)\n",
" updates, opt_state = optimizer.update(grad, opt_state, inplace=False)\n",
- " params = TorchOpt.apply_updates(params, updates, inplace=False)\n",
+ " params = torchopt.apply_updates(params, updates, inplace=False)\n",
"\n",
" pred = func(params, xs)\n",
" loss = ((pred - ys) ** 2).sum()\n",
@@ -365,7 +365,7 @@
"metadata": {},
"outputs": [],
"source": [
- "optim = TorchOpt.adam(lr=1., moment_requires_grad=False)"
+ "optim = torchopt.adam(lr=1., moment_requires_grad=False)"
]
},
{
@@ -374,7 +374,7 @@
"metadata": {},
"outputs": [],
"source": [
- "optim = TorchOpt.adam(lr=1., moment_requires_grad=True)"
+ "optim = torchopt.adam(lr=1., moment_requires_grad=True)"
]
},
{
@@ -383,7 +383,7 @@
"metadata": {},
"outputs": [],
"source": [
- "optim = TorchOpt.sgd(lr=1., momentum=0.8, moment_requires_grad=True)"
+ "optim = torchopt.sgd(lr=1., momentum=0.8, moment_requires_grad=True)"
]
},
{
@@ -418,7 +418,7 @@
}
],
"source": [
- "TorchOpt.accelerated_op_available(torch.device(\"cpu\"))"
+ "torchopt.accelerated_op_available(torch.device(\"cpu\"))"
]
},
{
@@ -438,7 +438,7 @@
}
],
"source": [
- "TorchOpt.accelerated_op_available(torch.device(\"cuda\"))"
+ "torchopt.accelerated_op_available(torch.device(\"cuda\"))"
]
},
{
@@ -448,7 +448,7 @@
"outputs": [],
"source": [
"net = Net(1).cuda()\n",
- "optim = TorchOpt.Adam(net.parameters(), lr=1., use_accelerated_op=True)"
+ "optim = torchopt.Adam(net.parameters(), lr=1., use_accelerated_op=True)"
]
},
{
@@ -457,7 +457,7 @@
"metadata": {},
"outputs": [],
"source": [
- "optim = TorchOpt.adam(lr=1., use_accelerated_op=True)"
+ "optim = torchopt.adam(lr=1., use_accelerated_op=True)"
]
}
],
diff --git a/tutorials/2_Visualization.ipynb b/tutorials/2_Visualization.ipynb
index f1ce0aa6..c8593b94 100644
--- a/tutorials/2_Visualization.ipynb
+++ b/tutorials/2_Visualization.ipynb
@@ -98,12 +98,12 @@
],
"source": [
"import torch\n",
- "import TorchOpt\n",
+ "import torchopt\n",
"\n",
"\n",
"x = torch.tensor(1., requires_grad=True)\n",
"y = 2 * x\n",
- "TorchOpt.visual.make_dot(y, params={'x': x, 'y': y})"
+ "torchopt.visual.make_dot(y, params={'x': x, 'y': y})"
]
},
{
@@ -245,8 +245,8 @@
}
],
"source": [
- "from torch import nn\n",
- "from torch.nn import functional as F\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
"\n",
"\n",
"class Net(nn.Module):\n",
@@ -264,7 +264,7 @@
"xs = torch.ones(batch_size, dim)\n",
"pred = net(xs)\n",
"loss = F.mse_loss(pred, torch.ones_like(pred))\n",
- "TorchOpt.visual.make_dot(loss, params=(net.named_parameters(), {\"loss\": loss}))"
+ "torchopt.visual.make_dot(loss, params=(net.named_parameters(), {\"loss\": loss}))"
]
},
{
@@ -317,7 +317,7 @@
"dim = 5\n",
"batch_size = 2\n",
"net = MetaNet(dim).cuda()\n",
- "optimizer = TorchOpt.MetaSGD(net, lr=1e-3)\n",
+ "optimizer = torchopt.MetaSGD(net, lr=1e-3)\n",
"meta_param = torch.tensor(1., requires_grad=True)\n",
"\n",
"xs = torch.ones(batch_size, dim).cuda()\n",
@@ -325,17 +325,17 @@
"pred = net(xs, meta_param)\n",
"loss = F.mse_loss(pred, torch.ones_like(pred))\n",
"# set enable_visual\n",
- "net_state_0 = TorchOpt.extract_state_dict(\n",
+ "net_state_0 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step0.')\n",
"optimizer.step(loss)\n",
"# set enable_visual\n",
- "net_state_1 = TorchOpt.extract_state_dict(\n",
+ "net_state_1 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step1.')\n",
"\n",
"pred = net(xs, meta_param)\n",
"loss = F.mse_loss(pred, torch.ones_like(pred))\n",
"# draw computation graph\n",
- "TorchOpt.visual.make_dot(loss,\n",
+ "torchopt.visual.make_dot(loss,\n",
" [net_state_0, net_state_1,\n",
" {\"meta_param\": meta_param, 'loss': loss}]\n",
" ).render(\"meta_graph\", format=\"png\")\n",
diff --git a/tutorials/3_Meta_Optimizer.ipynb b/tutorials/3_Meta_Optimizer.ipynb
index b76114f4..a846c81c 100644
--- a/tutorials/3_Meta_Optimizer.ipynb
+++ b/tutorials/3_Meta_Optimizer.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# TorchOpt as MetaOptimizer"
+ "# torchopt as Meta-Optimizer"
]
},
{
@@ -20,7 +20,7 @@
"source": [
"## 1. Basic API for differentiable optimizer\n",
"\n",
- "`MetaOptimizer` is the main class for our differnetiabl optimzier. Combined with the functional optimizer `TorchOpt.sgd` and `TorchOpt.adam` mentioned in the tutorial 1, we can define our high-level API `TorchOpt.MetaSGD` and `TorchOpt.MetaAdam`. We will discuss how this combination happens with `TorchOpt.chain` in Section 3. Let us consider the problem below."
+ "`MetaOptimizer` is the main class for our differnetiabl optimzier. Combined with the functional optimizer `torchopt.sgd` and `torchopt.adam` mentioned in the tutorial 1, we can define our high-level API `torchopt.MetaSGD` and `torchopt.MetaAdam`. We will discuss how this combination happens with `torchopt.chain` in Section 3. Let us consider the problem below."
]
},
{
@@ -56,7 +56,7 @@
"outputs": [],
"source": [
"import torch\n",
- "from torch import nn\n",
+ "import torch.nn as nn\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
@@ -105,9 +105,9 @@
}
],
"source": [
- "import TorchOpt\n",
+ "import torchopt\n",
"\n",
- "optim = TorchOpt.MetaSGD(net, lr=1.)\n",
+ "optim = torchopt.MetaSGD(net, lr=1.)\n",
"inner_loss = net(x)\n",
"optim.step(inner_loss)\n",
"outer_loss = net(x)\n",
@@ -160,21 +160,21 @@
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib import image as imgplt\n",
- "from torch.nn import functional as F\n",
+ "import torch.nn.functional as F\n",
"\n",
"net = Net()\n",
"x = torch.tensor(2., requires_grad=True)\n",
"y = torch.tensor(1.)\n",
"\n",
- "optim = TorchOpt.MetaAdam(net, lr=1., moment_requires_grad=False)\n",
+ "optim = torchopt.MetaAdam(net, lr=1., moment_requires_grad=False)\n",
"inner_loss = F.mse_loss(net(x), y)\n",
- "net_state_0 = TorchOpt.extract_state_dict(\n",
+ "net_state_0 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step0.')\n",
"optim.step(inner_loss)\n",
- "net_state_1 = TorchOpt.extract_state_dict(\n",
+ "net_state_1 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step1.')\n",
"outer_loss = F.mse_loss(net(x), y)\n",
- "TorchOpt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1,{'x': x, 'outer_loss': outer_loss}]).render(\"graph\", format=\"png\")\n",
+ "torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1,{'x': x, 'outer_loss': outer_loss}]).render(\"graph\", format=\"png\")\n",
"plt.figure(figsize=(15,15))\n",
"plt.imshow(imgplt.imread('graph.png'))"
]
@@ -219,16 +219,16 @@
"x = torch.tensor(2., requires_grad=True)\n",
"y = torch.tensor(1.)\n",
"\n",
- "optim = TorchOpt.MetaAdam(net, lr=1.)\n",
+ "optim = torchopt.MetaAdam(net, lr=1.)\n",
"inner_loss = F.mse_loss(net(x), y)\n",
- "net_state_0 = TorchOpt.extract_state_dict(\n",
+ "net_state_0 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step0.')\n",
"optim.step(inner_loss)\n",
- "net_state_1 = TorchOpt.extract_state_dict(\n",
+ "net_state_1 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step1.')\n",
"\n",
"outer_loss = F.mse_loss(net(x), y)\n",
- "TorchOpt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]).render(\"graph\", format=\"png\")\n",
+ "torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]).render(\"graph\", format=\"png\")\n",
"plt.figure(figsize=(15,15))\n",
"plt.imshow(imgplt.imread('graph.png'))"
]
@@ -255,7 +255,7 @@
"\n",
"We observe that how to reinitialize the inner-loop parameter in a new bi-level process vary in different Meta-Learning algorithms. For instance, in algorithm like MAML, every time a new task comes, we need to reset the parameters to the initial ones. In other cases such as Meta-gradient reinforcement learning, the inner-loop network parameter just inherit previous updated parameter to continue the new bi-level process.\n",
"\n",
- "We provide the `TorchOpt.extract_state_dict` and `TorchOpt.recover_state_dict` function to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `copy=True` to extract the copy of state dictionary."
+ "We provide the `torchopt.extract_state_dict` and `torchopt.recover_state_dict` function to extract and restore the state of network and optimizer. By default, the extracted state dictionary is a reference (this design is for accumulating gradient of multi-task batch training, MAML for example). You can also set `copy=True` to extract the copy of state dictionary."
]
},
{
@@ -275,13 +275,13 @@
"source": [
"net = Net()\n",
"x = torch.tensor(2., requires_grad=True)\n",
- "optim = TorchOpt.MetaAdam(net, lr=1.)\n",
- "init_net_state = TorchOpt.extract_state_dict(net)\n",
- "init_optim_state = TorchOpt.extract_state_dict(optim)\n",
+ "optim = torchopt.MetaAdam(net, lr=1.)\n",
+ "init_net_state = torchopt.extract_state_dict(net)\n",
+ "init_optim_state = torchopt.extract_state_dict(optim)\n",
"\n",
"# get the copy of state dictionary\n",
- "init_net_state_copy = TorchOpt.extract_state_dict(net, copy=True)\n",
- "init_optim_state_copy = TorchOpt.extract_state_dict(optim, copy=True)\n",
+ "init_net_state_copy = torchopt.extract_state_dict(net, copy=True)\n",
+ "init_optim_state_copy = torchopt.extract_state_dict(optim, copy=True)\n",
"\n",
"# Conduct 2 inner-loop optimization \n",
"inner_loss = net(x)\n",
@@ -291,8 +291,8 @@
"print(net.a)\n",
"\n",
"# Recover and reconduct 2 inner-loop optimization \n",
- "TorchOpt.recover_state_dict(net, init_net_state)\n",
- "TorchOpt.recover_state_dict(optim, init_optim_state)\n",
+ "torchopt.recover_state_dict(net, init_net_state)\n",
+ "torchopt.recover_state_dict(optim, init_optim_state)\n",
"inner_loss = net(x)\n",
"optim.step(inner_loss)\n",
"inner_loss = net(x)\n",
@@ -352,14 +352,14 @@
"\n",
"net = Net2Tasks()\n",
"x = torch.tensor(2., requires_grad=True)\n",
- "optim = TorchOpt.MetaSGD(net, lr=1.)"
+ "optim = torchopt.MetaSGD(net, lr=1.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Once we call `step` method of `MetaOptimizer`, the parameters of the network would be changed. We should use `TorchOpt.extract_state_dict` to extract state and use `TorchOpt.recover_state_dict` to recover the state. Note that if we use optimizers that have momentum buffers, we should also extract and recover them, vanilla SGD does not have momentum buffers so codes `init_optim_state = TorchOpt.extract_state_dict(optim)` and `TorchOpt.recover_state_dict(optim, init_optim_state)` have no effect."
+ "Once we call `step` method of `MetaOptimizer`, the parameters of the network would be changed. We should use `torchopt.extract_state_dict` to extract state and use `torchopt.recover_state_dict` to recover the state. Note that if we use optimizers that have momentum buffers, we should also extract and recover them, vanilla SGD does not have momentum buffers so codes `init_optim_state = torchopt.extract_state_dict(optim)` and `torchopt.recover_state_dict(optim, init_optim_state)` have no effect."
]
},
{
@@ -378,8 +378,8 @@
}
],
"source": [
- "init_net_state = TorchOpt.extract_state_dict(net)\n",
- "init_optim_state = TorchOpt.extract_state_dict(optim)\n",
+ "init_net_state = torchopt.extract_state_dict(net)\n",
+ "init_optim_state = torchopt.extract_state_dict(optim)\n",
"# it's SGD so state_dict is empty\n",
"print(init_optim_state)\n",
"\n",
@@ -389,8 +389,8 @@
"lo1.backward()\n",
"print(x.grad)\n",
"\n",
- "TorchOpt.recover_state_dict(net, init_net_state)\n",
- "TorchOpt.recover_state_dict(optim, init_optim_state)\n",
+ "torchopt.recover_state_dict(net, init_net_state)\n",
+ "torchopt.recover_state_dict(optim, init_optim_state)\n",
"li2 = net.task2(x)\n",
"optim.step(li2)\n",
"lo2 = net.task2(x)\n",
@@ -451,8 +451,8 @@
"net = Net()\n",
"x = torch.tensor(2., requires_grad=True)\n",
"\n",
- "impl = TorchOpt.combine.chain(TorchOpt.clip.clip_grad_norm(max_norm=2.), TorchOpt.sgd(lr=1., moment_requires_grad=True))\n",
- "optim = TorchOpt.MetaOptimizer(net, impl)\n",
+ "impl = torchopt.combine.chain(torchopt.clip.clip_grad_norm(max_norm=2.), torchopt.sgd(lr=1., moment_requires_grad=True))\n",
+ "optim = torchopt.MetaOptimizer(net, impl)\n",
"li = net(x)\n",
"optim.step(li)\n",
"lo = net(x)\n",
@@ -496,7 +496,7 @@
}
],
"source": [
- "TorchOpt.accelerated_op_available(torch.device(\"cpu\"))"
+ "torchopt.accelerated_op_available(torch.device(\"cpu\"))"
]
},
{
@@ -516,7 +516,7 @@
}
],
"source": [
- "TorchOpt.accelerated_op_available(torch.device(\"cuda\"))"
+ "torchopt.accelerated_op_available(torch.device(\"cuda\"))"
]
},
{
@@ -552,16 +552,16 @@
"x = torch.tensor(2., requires_grad=True, device=torch.device(\"cuda\"))\n",
"y = torch.tensor(1., device=torch.device(\"cuda\"))\n",
"\n",
- "optim = TorchOpt.MetaAdam(net, lr=1., use_accelerated_op=True)\n",
+ "optim = torchopt.MetaAdam(net, lr=1., use_accelerated_op=True)\n",
"\n",
"inner_loss = F.mse_loss(net(x), y)\n",
- "net_state_0 = TorchOpt.extract_state_dict(\n",
+ "net_state_0 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step0.')\n",
"optim.step(inner_loss)\n",
- "net_state_1 = TorchOpt.extract_state_dict(\n",
+ "net_state_1 = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step1.')\n",
"outer_loss = F.mse_loss(net(x), y)\n",
- "TorchOpt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1,{'x': x, 'outer_loss': outer_loss}]).render(\"graph\", format=\"png\")\n",
+ "torchopt.visual.make_dot(outer_loss, params=[net_state_0, net_state_1,{'x': x, 'outer_loss': outer_loss}]).render(\"graph\", format=\"png\")\n",
"plt.figure(figsize=(15,15))\n",
"plt.imshow(imgplt.imread('graph.png'))"
]
diff --git a/tutorials/4_Stop_Gradient.ipynb b/tutorials/4_Stop_Gradient.ipynb
old mode 100755
new mode 100644
index 4c13f420..21492fc5
--- a/tutorials/4_Stop_Gradient.ipynb
+++ b/tutorials/4_Stop_Gradient.ipynb
@@ -4,14 +4,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# TorchOpt.stop_gradient in meta learning"
+ "# `torchopt.stop_gradient` in Meta-Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "In this tutoial, we will illustrate the usage of TorchOpt.stop_gradient with a meta-learning example. We use TorchOpt.visual to help us visualize what is going on in automatic differentiation. Firstly, we define a simple network and the objective function for inner, outer optimization."
+ "In this tutoial, we will illustrate the usage of torchopt.stop_gradient with a meta-learning example. We use torchopt.visual to help us visualize what is going on in automatic differentiation. Firstly, we define a simple network and the objective function for inner, outer optimization."
]
},
{
@@ -21,8 +21,8 @@
"outputs": [],
"source": [
"import torch\n",
- "from torch import nn\n",
- "from torch.nn import functional as F\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
@@ -69,8 +69,8 @@
"metadata": {},
"outputs": [],
"source": [
- "import TorchOpt\n",
- "from TorchOpt import MetaSGD\n",
+ "import torchopt\n",
+ "from torchopt import MetaSGD\n",
"from matplotlib import image as imgplt\n",
"from matplotlib import pyplot as plt\n",
"\n",
@@ -125,7 +125,7 @@
"# inner loss\n",
"loss = loss_fn(net(x), y)\n",
"print(f\"inner loss: {loss:.4f}\")\n",
- "TorchOpt.visual.make_dot(loss).render(\"full_graph\", format=\"png\")\n",
+ "torchopt.visual.make_dot(loss).render(\"full_graph\", format=\"png\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(imgplt.imread('full_graph.png'))"
]
@@ -195,12 +195,12 @@
],
"source": [
"# extract state_dict for updated network\n",
- "one_step_net_state = TorchOpt.extract_state_dict(net)\n",
- "one_step_optim_state = TorchOpt.extract_state_dict(optim)\n",
+ "one_step_net_state = torchopt.extract_state_dict(net)\n",
+ "one_step_optim_state = torchopt.extract_state_dict(optim)\n",
"# calculate outer loss\n",
"outer_loss = loss_fn(net(x), y)\n",
"print(f\"outer loss: {outer_loss:.4f}\")\n",
- "TorchOpt.visual.make_dot(outer_loss).render(\"full_graph\", format=\"png\")\n",
+ "torchopt.visual.make_dot(outer_loss).render(\"full_graph\", format=\"png\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(imgplt.imread('full_graph.png'))"
]
@@ -294,7 +294,7 @@
"loss = inner_loss * meta_parameter\n",
"optim.step(loss)\n",
"outer_loss = loss_fn(net(x), y)\n",
- "TorchOpt.visual.make_dot(outer_loss).render(\"full_graph\", format=\"png\")\n",
+ "torchopt.visual.make_dot(outer_loss).render(\"full_graph\", format=\"png\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(imgplt.imread('full_graph.png'))\n",
"meta_optim.zero_grad()\n",
@@ -306,7 +306,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "From the graph we can see, directly conducting the second bi-level process links the graph of first and second bi-level process together. We should manually stop gradient with `TorchOpt.stop_gradient`. `TorchOpt.stop_gradient` will detach the node of gradient graph and make it become a leaf node. It allows the input of network, optimizer, or state dictionary and the gradient operation happens in an inplace manner.\n",
+ "From the graph we can see, directly conducting the second bi-level process links the graph of first and second bi-level process together. We should manually stop gradient with `torchopt.stop_gradient`. `torchopt.stop_gradient` will detach the node of gradient graph and make it become a leaf node. It allows the input of network, optimizer, or state dictionary and the gradient operation happens in an inplace manner.\n",
"\n",
"Let's use recover_state_dict to come back to one-step updated states."
]
@@ -318,8 +318,8 @@
"outputs": [],
"source": [
"# Reset to previous one-step updated states\n",
- "TorchOpt.recover_state_dict(net, one_step_net_state)\n",
- "TorchOpt.recover_state_dict(optim, one_step_optim_state)"
+ "torchopt.recover_state_dict(net, one_step_net_state)\n",
+ "torchopt.recover_state_dict(optim, one_step_optim_state)"
]
},
{
@@ -356,14 +356,14 @@
],
"source": [
"# stop gradient and make them become the leaf node\n",
- "TorchOpt.stop_gradient(net)\n",
- "TorchOpt.stop_gradient(optim)\n",
+ "torchopt.stop_gradient(net)\n",
+ "torchopt.stop_gradient(optim)\n",
"\n",
"inner_loss = loss_fn(net(x), y)\n",
"loss = inner_loss * meta_parameter\n",
"optim.step(loss)\n",
"outer_loss = loss_fn(net(x), y)\n",
- "TorchOpt.visual.make_dot(outer_loss).render(\"full_graph\", format=\"png\")\n",
+ "torchopt.visual.make_dot(outer_loss).render(\"full_graph\", format=\"png\")\n",
"plt.figure(figsize=(10,10))\n",
"plt.imshow(imgplt.imread('full_graph.png'))\n",
"meta_optim.zero_grad()\n",