torchfitter is a simple library to ease the training of PyTorch models. It features a class called Trainer that includes the basic functionality to fit models in a Keras-like style. Internally, torchfitter leverages the power of accelerate to handle device management. The library also provides a callbacks API that can be used to interact with the model during the training process, as well as a set of basic regularization procedures.
Normal user
pip install torchfitter
This library does not ship with CUDA or XLA binaries. Follow the official PyTorch documentation for more information on how to install the CUDA binaries.
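Once the CUDA-enabled PyTorch binaries are installed, a quick sanity check with plain PyTorch (not a torchfitter API) confirms that the GPU is visible:

import torch

print(torch.__version__)            # installed PyTorch version
print(torch.cuda.is_available())    # True if a CUDA-enabled GPU is usable
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first visible GPU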
Developer
git clone https://github.com/Xylambda/torchfitter.git
pip install -e torchfitter/. -r torchfitter/requirements-dev.txt
To run the tests you must install the library as a developer.
cd torchfitter/
pytest -v tests/
Feature | Supported | Not supported | Planned
---|---|---|---
Basic training loop | x | |
Gradient Clipping | x | |
Gradient Accumulation | x | |
Multi-device support | x | |
Regularization | x | |
In-loop metrics support | x | |
Mixed precision training | x | |
Callbacks System | x | |
Hyperparameter search | x | |
Warm Training | | x | x
Assume we already have DataLoaders for the train and validation sets.
from torch.utils.data import DataLoader
train_loader = DataLoader(...)
val_loader = DataLoader(...)
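If you do not have loaders yet, the following is a minimal sketch that builds them from synthetic tensors; the data, split and batch size are purely illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset

# synthetic regression data, for illustration only
X = torch.randn(1000, 1)
y = 3 * X + 0.5 + 0.1 * torch.randn(1000, 1)

train_dataset = TensorDataset(X[:800], y[:800])
val_dataset = TensorDataset(X[800:], y[800:])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)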
Then, create the optimizer and the loss criterion as usual, and pass them to the trainer along with the PyTorch model. You can also add a regularization procedure if you need or want one. The same goes for callbacks: create the desired callbacks and pass them to the trainer as a list.
import torch.nn as nn
import torch.optim as optim
from torchfitter.trainer import Trainer
from torchfitter.callbacks import (
LoggerCallback,
EarlyStopping,
LearningRateScheduler,
L1Regularization,
)
model = nn.Linear(in_features=1, out_features=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())
l1_reg = L1Regularization(regularization_rate=0.01, biases=False)
# callbacks
logger = LoggerCallback(update_step=50)
early_stopping = EarlyStopping(patience=50, load_best=True, path='checkpoint.pt')
scheduler = LearningRateScheduler(
scheduler=optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.9)
)
trainer = Trainer(
model=model,
criterion=criterion,
optimizer=optimizer,
mixed_precision="fp16",
accumulate_iter=4, # accumulate gradients every 4 iterations
gradient_clipping='norm',
gradient_clipping_kwrgs={'max_norm': 1.0, 'norm_type': 2.0},
callbacks=[l1_reg, scheduler, early_stopping, logger]
)
history = trainer.fit(train_loader, val_loader, epochs=1000)
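After fitting, the trained weights can be persisted with plain PyTorch; nothing here is torchfitter-specific, and the file name is just an example:

import torch
import torch.nn as nn

# save the trained weights of the model that was passed to the Trainer
torch.save(model.state_dict(), 'linear_model.pt')

# restore them later into a fresh instance of the same architecture
restored_model = nn.Linear(in_features=1, out_features=1)
restored_model.load_state_dict(torch.load('linear_model.pt'))
restored_model.eval()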
Since torchfitter leverages the power of accelerate, device management relies on the latter. You can pass your own accelerate.Accelerator object to fine-tune its parameters:
from accelerate import Accelerator
from torchfitter.trainer import Trainer
accelerator = Accelerator(...)
trainer = Trainer(
accelerator=accelerator,
**kwargs
)
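For example, a manually configured Accelerator could look like the sketch below; cpu and mixed_precision are standard accelerate arguments, and the remaining Trainer arguments are the same ones used in the example above:

from accelerate import Accelerator
from torchfitter.trainer import Trainer

# standard accelerate options: force CPU execution or pick the mixed precision mode here
accelerator = Accelerator(cpu=False, mixed_precision='fp16')

trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    accelerator=accelerator,
)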
Callbacks allow you to interact with the model during the fitting process. They provide different methods that are called at different stages of the training loop. To create a callback, simply extend the base class and fill in the desired methods.
import torch
from torchfitter.conventions import ParamsDict
from torchfitter.callbacks.base import Callback
class ModelCheckpoint(Callback):
    def __init__(self, path):
        super(ModelCheckpoint, self).__init__()
        self.path = path

    def __repr__(self) -> str:
        return "ModelCheckpoint()"

    def on_epoch_end(self, params_dict):
        # get params
        accelerator = params_dict[ParamsDict.ACCELERATOR]
        epoch = params_dict[ParamsDict.EPOCH_NUMBER]

        # ensure model is safe to save
        _model = params_dict[ParamsDict.MODEL]
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(_model)

        # actual saving
        fname = self.path / f'model_epoch_{epoch}.pt'
        accelerator.save(unwrapped_model.state_dict(), fname)
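The custom callback is then used like any built-in one: instantiate it and include it in the callbacks list passed to the Trainer. The checkpoint directory below is a placeholder, and logger and early_stopping are the callbacks created in the earlier example:

from pathlib import Path

# directory where the per-epoch checkpoints will be written
checkpoint_dir = Path('checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

model_checkpoint = ModelCheckpoint(path=checkpoint_dir)
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    callbacks=[model_checkpoint, logger, early_stopping],
)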
Each method receives params_dict, a dictionary containing the internal training parameters. You can list the key-value pair of each parameter from the conventions module:
>>> from torchfitter.conventions import ParamsDict
>>> [(x, getattr(ParamsDict, x)) for x in ParamsDict.__dict__ if not x.startswith('__')]
You can also check the docstring to understand the meaning of each parameter:
>>> from torchfitter.conventions import ParamsDict
>>> print(ParamsDict.__doc__)
NOTE: the callbacks design can be considered a port of the Keras design. I am NOT the author of this callback system design, although I made some minor design changes. Find more in the Credits section.
- Do you know Pytorch-Lightning/FastAI?
I know them and I think they are awesome. This is a personal project, though I must say the trainer is reasonably well-equipped.
- Why is the validation loader not optional?
Because I think it enforces good ML practices that way.
- Why didn't you implement the optimization steps in the model object?
It is certainly another approach you can take when building an optimization loop (PyTorch-Lightning works this way), but I don't like my abstract data types tracking too many things on top of being torch.nn.Module types. Functionality should be clear and atomic: the model tracks gradients and the trainer takes care of the optimization process.
- I have a suggestion/question
Thank you! Do not hesitate to open an issue and I'll do my best to answer you.
If you've used this library for your projects, please cite it:
@misc{alejandro2019torchfitter,
title={torchfitter - Simple Trainer to Optimize PyTorch Models},
author={Alejandro Pérez-Sanjuán},
year={2020},
howpublished={\url{https://github.com/Xylambda/torchfitter}},
}