Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

chore: rename package to follow PEP8 naming convention #20

Merged
merged 11 commits into from
Jul 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# https://editorconfig.org/

root = true

[*]
charset = utf-8
end_of_line = lf
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true

[*.py]
indent_size = 4
src_paths=torchopt,tests,examples

[*.md]
indent_size = 2
x-soft-wrap-text = true

[*.rst]
indent_size = 4
x-soft-wrap-text = true

[Makefile]
indent_style = tab

[*.cpp]
indent_size = 2

[*.h]
indent_size = 2

[*.cuh?]
indent_size = 2
8 changes: 4 additions & 4 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Please try to provide a minimal example to reproduce the bug. Error messages and
Please use the markdown code blocks for both code and stack traces.

```python
import metarl
import torchopt
```

```bash
Expand All @@ -43,8 +43,8 @@ Describe the characteristic of your environment:
* Versions of any other relevant libraries

```python
import metarl, numpy, sys
print(metarl.__version__, numpy.__version__, sys.version, sys.platform)
import torchopt, numpy, sys
print(torchopt.__version__, numpy.__version__, sys.version, sys.platform)
```

## Additional context
Expand All @@ -58,5 +58,5 @@ If you know or suspect the reason for this bug, paste the code lines and suggest
## Checklist

- [ ] I have checked that there is no similar issue in the repo (**required**)
- [ ] I have read the [documentation](https://metarl.readthedocs.io/) (**required**)
- [ ] I have read the [documentation](https://torchopt.readthedocs.io/) (**required**)
- [ ] I have provided a minimal working example to reproduce the bug (**required**)
4 changes: 2 additions & 2 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213

- [ ] I have raised an issue to propose this change ([required](https://metarl.readthedocs.io/en/latest/pages/contributing.html) for new features and bug fixes)
- [ ] I have raised an issue to propose this change ([required](https://torchopt.readthedocs.io/en/latest/pages/contributing.html) for new features and bug fixes)

## Types of changes

Expand All @@ -32,7 +32,7 @@ What types of changes does your code introduce? Put an `x` in all the boxes that
Go over all the following points, and put an `x` in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

- [ ] I have read the [CONTRIBUTION](https://metarl.readthedocs.io/en/latest/pages/contributing.html) guide (**required**)
- [ ] I have read the [CONTRIBUTION](https://torchopt.readthedocs.io/en/latest/pages/contributing.html) guide (**required**)
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
Expand Down
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
.idea
build
__pycache__
TorchOpt/**/*.so
TorchOpt.egg-info
torchopt/**/*.so
torchopt.egg-info
dist
**/.ipynb_checkpoints/*

Expand Down Expand Up @@ -152,4 +152,4 @@ dmypy.json
.pytype/

# Cython debug symbols
cython_debug/
cython_debug/
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ authors:
version: 0.4.1
date-released: "2022-04-09"
license: Apache-2.0
repository-code: "https://github.com/metaopt/TorchOpt"
repository-code: "https://github.com/metaopt/torchopt"
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================

cmake_minimum_required(VERSION 3.1)
project(TorchOpt LANGUAGES CXX CUDA)
project(torchopt LANGUAGES CXX CUDA)

find_package(CUDA REQUIRED)

Expand Down
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
print-% : ; @echo $* = $($*)
SHELL = /bin/bash
PROJECT_NAME = TorchOpt
PROJECT_NAME = torchopt
PROJECT_PATH = ${PROJECT_NAME}/
PROJECT_FOLDER = $(PROJECT_NAME) examples include src tests
PYTHON_FILES = $(shell find . -type f -name "*.py")
PYTHON_FILES = $(shell find examples torchopt tests -type f -name "*.py" -o -name "*.pyi")
CPP_FILES = $(shell find . -type f -name "*.h" -o -name "*.cpp" -o -name "*.cuh" -o -name "*.cu")
COMMIT_HASH = $(shell git log -1 --format=%h)
COPYRIGHT = "MetaOPT Team. All Rights Reserved."
Expand Down Expand Up @@ -66,7 +66,8 @@ flake8: flake8-install
flake8 $(PYTHON_FILES) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics

py-format: py-format-install
isort --check $(PYTHON_FILES) && yapf -ir $(PYTHON_FILES)
isort --project torchopt --check $(PYTHON_FILES) && \
yapf --in-place --recursive $(PYTHON_FILES)

mypy: mypy-install
mypy $(PROJECT_NAME)
Expand Down Expand Up @@ -103,4 +104,3 @@ format: py-format-install clang-format-install
yapf -ir $(PYTHON_FILES)
clang-format-11 -style=file -i $(CPP_FILES)
addlicense -c $(COPYRIGHT) -l apache -y 2022 $(PROJECT_FOLDER)

71 changes: 36 additions & 35 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
</div>

**TorchOpt** is a high-performance optimizer library built upon [PyTorch](https://pytorch.org/) for easy implementation of functional optimization and gradient-based meta-learning. It consists of two main features:
- TorchOpt provides functional optimizer which enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX.
- TorchOpt provides functional optimizer which enables [JAX-like](https://github.com/google/jax) composable functional optimizer for PyTorch. With TorchOpt, one can easily conduct neural network optimization in PyTorch with functional style optimizer, similar to [Optax](https://github.com/deepmind/optax) in JAX.
- With the desgin of functional programing, TorchOpt provides efficient, flexible, and easy-to-implement differentiable optimizer for gradient-based meta-learning research. It largely reduces the efforts required to implement sophisticated meta-learning algorithms.

--------------------------------------------------------------------------------
Expand All @@ -21,43 +21,44 @@ The README is organized as follows:
- [Installation](#installation)
- [Future Plan](#future-plan)
- [The Team](#the-team)
- [Citing TorchOpt](#citing-torchopt)


## TorchOpt as Functional Optimizer
The desgin of TorchOpt follows the philosophy of functional programming. Aligned with [functorch](https://github.com/pytorch/functorch), users can conduct functional style programing with models, optimizers and training in PyTorch. We use the Adam optimizer as an example in the following illustration. You can also check out the tutorial notebook [Functional Optimizer](./tutorials/1_Functional_Optimizer.ipynb) for more details.
### Optax-Like API
For those users who prefer fully functional programing, we offer Optax-Like API by passing gradients and optimizers states to the optimizer function. We design base class `TorchOpt.Optimizer` that has the same interface as `torch.optim.Optimizer`. Here is an example coupled with functorch:
For those users who prefer fully functional programing, we offer Optax-Like API by passing gradients and optimizers states to the optimizer function. We design base class `torchopt.Optimizer` that has the same interface as `torch.optim.Optimizer`. Here is an example coupled with functorch:
```python
import torch
from torch import nn
from torch import data
from nn import functional as F
import functorch
import TorchOpt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt
from torch.utils.data import DataLoader

class Net(nn.Module):...
class Net(nn.Module): ...

class Loader(data.DataLoader):...
class Loader(DataLoader): ...

net = Net() # init
loader = Loader()
optimizer = TorchOpt.adam()
optimizer = torchopt.adam()
func, params = functorch.make_functional(net) # use functorch extract network parameters
opt_state = optimizer.init(params) # init optimizer
xs, ys = next(loader) # get data
pred = func(params, xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss
loss = F.cross_entropy(pred, ys) # compute loss
grad = torch.autograd.grad(loss, params) # compute gradients
updates, opt_state = optimizer.update(grad, opt_state) # get updates
params = TorchOpt.apply_updates(params, updates) # update network parameters
params = torchopt.apply_updates(params, updates) # update network parameters
```
### PyTorch-Like API
We also offer origin PyTorch APIs (e.g. `zero_grad()` or `step()`) by warpping our Optax-Like API for traditional PyTorch user:
<!-- The functional programming can easily disguise as origin PyTorch APIs (e.g. `zero_grad()` or `step()`), the only we need is to build a new class that contains both the optimizer function and optimizer states. -->
```python
net = Net() # init
loader = Loader()
optimizer = TorchOpt.Adam(net.parameters())
optimizer = torchopt.Adam(net.parameters())
xs, ys = next(loader) # get data
pred = net(xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss
Expand All @@ -71,15 +72,15 @@ On top of the same optimization function as `torch.optim`, an important benefit
# get updates
updates, opt_state = optimizer.update(grad, opt_state, inplace=False)
# update network parameters
params = TorchOpt.apply_updates(params, updates, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)
```
## TorchOpt as Differentiable Optimizer for Meta-Learning
Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimisation process with *inner loop* updating the network parameters and *outer loop* updating meta parameters. The figure below illustrates the basic formulation for meta-optimization in Meta-Learning. The main feature is that the gradients of *outer loss* will back-propagate through all `inner.step` operations.
Meta-Learning has gained enormous attention in both Supervised Learning and Reinforcement Learning. Meta-Learning algorithms often contain a bi-level optimisation process with *inner loop* updating the network parameters and *outer loop* updating meta parameters. The figure below illustrates the basic formulation for meta-optimization in Meta-Learning. The main feature is that the gradients of *outer loss* will back-propagate through all `inner.step` operations.
<div align="center">
<img src=/image/TorchOpt.png width=85% />
</div>

Since network parameters become a node of computation graph, a flexible Meta-Learning library should enable users manually control the gradient graph connection which means that users should have access to the network parameters and optimizer states for manually detaching or connecting the computation graph. In PyTorch designing, the network parameters or optimizer states are members of network (a.k.a. `nn.Module`) or optimizer (a.k.a. `optim.Optimizer`), this design significantly introducing difficulty for user control network parameters or optimizer states. Previous differentiable optimizer Repo [higher](https://github.com/facebookresearch/higher), [learn2learn](https://github.com/learnables/learn2learn) follows the PyTorch designing which leads to inflexible API.
Since network parameters become a node of computation graph, a flexible Meta-Learning library should enable users manually control the gradient graph connection which means that users should have access to the network parameters and optimizer states for manually detaching or connecting the computation graph. In PyTorch designing, the network parameters or optimizer states are members of network (a.k.a. `nn.Module`) or optimizer (a.k.a. `optim.Optimizer`), this design significantly introducing difficulty for user control network parameters or optimizer states. Previous differentiable optimizer Repo [higher](https://github.com/facebookresearch/higher), [learn2learn](https://github.com/learnables/learn2learn) follows the PyTorch designing which leads to inflexible API.

In contrast to them, TorchOpt realizes differentiable optimizer with functional programing, where Meta-Learning researchers could control the network parameters or optimizer states as normal variables (a.k.a. `torch.Tensor`). This functional optimizer design of TorchOpt is beneficial for implementing complex gradient flow Meta-Learning algorithms and allow us to improve computational efficiency by using techniques like operator fusion.

Expand All @@ -91,8 +92,8 @@ We hope meta-learning researchers could control the network parameters or optimi

### Meta-Learning API
<!-- Meta-Learning algorithms often use *inner loop* to update network parameters and compute an *outer loss* then back-propagate the *outer loss*. So the optimizer used in the *inner loop* should be differentiable. Thanks to the functional design, we can easily realize this requirement. -->
- We design a base class `TorchOpt.MetaOptimizer` for managing network updates in Meta-Learning. The constructor of `MetaOptimizer` takes as input the network rather than network parameters. `MetaOptimizer` exposed interface `step(loss)` takes as input the loss for step the network parameter. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
- We offer `TorchOpt.chain` which can apply a list of chainable update transformations. Combined with `MetaOptimizer`, it can help you conduct gradient transformation such as gradient clip before the Meta optimizer steps. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
- We design a base class `torchopt.MetaOptimizer` for managing network updates in Meta-Learning. The constructor of `MetaOptimizer` takes as input the network rather than network parameters. `MetaOptimizer` exposed interface `step(loss)` takes as input the loss for step the network parameter. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
- We offer `torchopt.chain` which can apply a list of chainable update transformations. Combined with `MetaOptimizer`, it can help you conduct gradient transformation such as gradient clip before the Meta optimizer steps. Refer to the tutorial notebook [Meta Optimizer](./tutorials/2_Meta_Optimizer.ipynb) for more details.
- We observe that different Meta-Learning algorithms vary in inner-loop parameter recovery. TorchOpt provides basic functions for users to extract or recover network parameters and optimizer states anytime anywhere they want.
- Some algorithms such as [MGRL](https://proceedings.neurips.cc/paper/2018/file/2715518c875999308842e3455eda2fe3-Paper.pdf) initialize the inner-loop parameters inherited from previous inner-loop process when conducting a new bi-level process. TorchOpt also provides a finer function `stop_gradient` for manipulating the gradient graph, which is helpful for this kind of algortihms. Refer to the notebook [Stop Gradient](./tutorials/4_Stop_Gradient.ipynb) for more details.

Expand All @@ -101,40 +102,40 @@ We give an example of [MAML](https://arxiv.org/abs/1703.03400) with inner-loop A
```python
net = Net() # init
# the constructor `MetaOptimizer` takes as input the network
inner_optim = TorchOpt.MetaAdam(net)
outer_optim = TorchOpt.Adam(net.parameters())
inner_optim = torchopt.MetaAdam(net)
outer_optim = torchopt.Adam(net.parameters())

for train_iter in range(train_iters):
outer_loss = 0
for task in range(tasks):
loader = Loader(tasks)

# store states at the inital points
net_state = TorchOpt.extract_state_dict(net) # extract state
optim_state = TorchOpt.extract_state_dict(inner_optim)
net_state = torchopt.extract_state_dict(net) # extract state
optim_state = torchopt.extract_state_dict(inner_optim)
for inner_iter in range(inner_iters):
# compute inner loss and perform inner update
xs, ys = next(loader)
pred = net(xs)
inner_loss = F.cross_entropy(pred, ys)
inner_loss = F.cross_entropy(pred, ys)
inner_optim.step(inner_loss)
# compute outer loss and back-propagate
xs, ys = next(loader)
xs, ys = next(loader)
pred = net(xs)
outer_loss += F.cross_entropy(pred, ys)

# recover network and optimizer states at the inital point for the next task
TorchOpt.recover_state_dict(inner_optim, optim_state)
TorchOpt.recover_state_dict(net, net_state)
torchopt.recover_state_dict(inner_optim, optim_state)
torchopt.recover_state_dict(net, net_state)

outer_loss /= len(tasks) # task average
outer_optim.zero_grad()
outer_loss.backward()
outer_optim.step()

# stop gradient if necessary
TorchOpt.stop_gradient(net)
TorchOpt.stop_gradient(inner_optim)
torchopt.stop_gradient(net)
torchopt.stop_gradient(inner_optim)
```
## Examples
In *examples/*, we offer serveral examples of functional optimizer and 5 light-weight meta-learning examples with TorchOpt. The meta-learning examples covers 2 Supervised Learning and 3 Reinforcement Learning algorithms.
Expand Down Expand Up @@ -168,13 +169,13 @@ Requirements
- (Optional) For visualizing computation graphs
- [Graphviz](https://graphviz.org/download/) (for Linux users use `apt/yum install graphviz` or `conda install -c anaconda python-graphviz`)
```bash
pip install TorchOpt
pip install torchopt
```

You can also build shared libraries from source, use:
```bash
git clone git@github.com:metaopt/TorchOpt.git
cd TorchOpt
git clone git@github.com:metaopt/torchopt.git
cd torchopt
python setup.py build_from_source
```
## Future Plan
Expand All @@ -196,6 +197,6 @@ If you find TorchOpt useful, please cite it in your publications.
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/metaopt/TorchOpt}},
howpublished = {\url{https://github.com/metaopt/torchopt}},
}
```
64 changes: 0 additions & 64 deletions TorchOpt/__init__.py

This file was deleted.

Loading