Note: I've recently revamped the internals to be significantly cleaner and to allow for projections (e.g. GaLore, GWT). New docs are WIP, but should be finished soon after I re-add all modules.
torchzero implements a large number of chainable optimization modules that can be combined to create custom optimizers:
```python
import torchzero as tz

optimizer = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Cautious(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-4),
)

# standard training loop
for inputs, targets in dataset:
    preds = model(inputs)
    loss = criterion(preds, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
Each module takes the output of the previous module and applies a further transformation. This modular design avoids redundant code, such as reimplementing cautioning, orthogonalization, Laplacian smoothing, and so on, for every optimizer. It also makes it easy to experiment with grafting, interpolation between different optimizers, and weirder combinations such as nested momentum.
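For instance, because each module only sees the output of the previous one, reordering modules changes the algorithm. The sketch below is an illustration only (it assumes `tz.m.WeightDecay` simply adds a decay term to whatever tensor it receives); it contrasts applying weight decay to the raw gradient with applying it to the final update:

```python
import torchzero as tz

# Decay applied to the raw gradient, before the Adam transformation
# (coupled, L2-style decay). Assumes WeightDecay adds a decay term
# to whatever tensor it receives.
coupled = tz.Modular(
    model.parameters(),
    tz.m.WeightDecay(1e-4),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)

# Decay applied after the Adam update and learning-rate scaling
# (decoupled, AdamW-style decay), as in the example above.
decoupled = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-4),
)
```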
Modules are not limited to gradient transformations. They can perform other operations like line searches, exponential moving average (EMA) and stochastic weight averaging (SWA), gradient accumulation, gradient approximation, and more.
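Modules that need to re-evaluate the loss (for example line searches and gradient approximation) generally require a closure. Below is a minimal sketch that assumes a backtracking line-search module named `tz.m.Backtracking` and an LBFGS-style closure protocol; check the documentation for the actual module name and closure signature:

```python
import torchzero as tz

# NOTE: tz.m.Backtracking and the LBFGS-style closure protocol are
# assumptions for illustration; see the docs for the exact names.
optimizer = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Backtracking(),  # line search along the Adam direction (assumed name)
)

for inputs, targets in dataset:

    def closure():
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        return loss

    optimizer.step(closure)
```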
There are over 100 modules, all accessible in the `tz.m` namespace. For example, the Adam update rule is available as `tz.m.Adam`. A complete list of modules is available in the documentation.
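For example, chaining just `tz.m.Adam` with `tz.m.LR` reproduces a plain Adam optimizer (a minimal sketch using only the modules that already appear above):

```python
import torchzero as tz

# Plain Adam: the Adam update rule followed by a fixed learning rate.
optimizer = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-3))
```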