add gradient clipping to create_supervised_trainer() #419

Open
lmarti opened this issue Jan 30, 2019 · 8 comments

@lmarti

lmarti commented Jan 30, 2019

It would be good to add gradient clipping to the trainers created by create_supervised_trainer. This is already provided by torch.nn.utils.clip_grad_norm_.

One possible implementation could be:

```python
import math

from torch.nn.utils import clip_grad_norm_
from ignite.engine import Engine, _prepare_batch


def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              gradient_clip=math.inf):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        gradient_clip (float, optional): maximum L2 norm of the model's gradients, passed as `max_norm`
            to `clip_grad_norm_` (default: `math.inf`, i.e. no effective clipping).

    Note: `engine.state.output` for this engine is the loss of the processed batch.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        # Rescale all gradients so that their total norm is at most `gradient_clip`.
        clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        return loss.item()

    return Engine(_update)
```

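
For reference, `clip_grad_norm_` treats the gradients of all parameters as a single vector and rescales them by one common factor whenever their combined L2 norm exceeds `max_norm`. The sketch below reproduces that arithmetic on plain Python lists; `clip_by_total_norm` is a hypothetical helper (not part of PyTorch or ignite), and the small `eps` mirrors the one PyTorch adds to the denominator.

```python
import math


def clip_by_total_norm(grads, max_norm, eps=1e-6):
    """Hypothetical pure-Python analogue of torch.nn.utils.clip_grad_norm_.

    grads: list of lists of floats, one inner list per parameter.
    Scales every gradient in place by the same factor so that the global
    L2 norm is at most max_norm, and returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Only shrink gradients; never scale them up.
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [4.0]]  # global norm is 5.0
norm = clip_by_total_norm(grads, max_norm=1.0)
print(norm)   # 5.0
print(grads)  # approximately [[0.6, 0.0], [0.8]]
```

With `max_norm=math.inf` the coefficient is infinite, so the gradients are left unchanged, although the norm is still computed on every call.
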
@vfdev-5
Collaborator

vfdev-5 commented Jan 30, 2019

@lmarti thanks for the feedback. We discussed a similar question in #375.
Methods like create_supervised_trainer are just helpers for basic usage; for anything more specific, use Engine directly with a custom process_fn.

We can discuss whether such a trainer would be useful and could be placed in the contrib.engines module.
cc @willprice

@lmarti
Author

lmarti commented Jan 30, 2019

Sorry, I missed that one. I had the same doubts about moving it to contrib.engines. My argument against doing so is that the code would be nearly identical to create_supervised_trainer. In any case, it's your call.

@AntoinePrv
Contributor

AntoinePrv commented Feb 1, 2019

A more general way to support this would be to fire a new event (GRADIENT_COMPUTED?) between loss.backward() and optimizer.step().

It doesn't have to be added to the core events; it could be added just for the supervised trainer, as we did with supervised_tbptt_trainer.

@vfdev-5
Collaborator

vfdev-5 commented Feb 1, 2019

@AntoinePrv I think it would be simpler to write a custom processing function than to add custom events.

@sudarshan85

@vfdev-5 While I agree with you, it would be nice to have options. In particular, it would be great to have as many events as the fastai callback system provides. The callbacks listed there are (with the corresponding ignite events in parentheses):

  1. on_train_begin() (Events.STARTED)
  2. on_epoch_begin() (Events.EPOCH_STARTED)
  3. on_batch_begin() (Events.ITERATION_STARTED)
  4. on_loss_begin()*: Called after forward pass but before loss has been computed.
  5. on_backward_begin()*: Called after forward pass and loss computation but before backprop.
  6. on_backward_end()*: Called after backprop but before optimizer step.
  7. on_step_end()*: Called after optimizer step but before gradients are zeroed.
  8. on_batch_end() (Events.ITERATION_COMPLETED)
  9. on_epoch_end() (Events.EPOCH_COMPLETED)
  10. on_train_end() (Events.COMPLETED)

Items marked with * have no corresponding ignite events. Having them as options would provide the following advantages:

  1. It adds even more flexibility to the engine
  2. Many of fastai's callbacks provide training tips and other conveniences, such as LRFinder, gradient clipping, etc. These would be easy to port if we had these events.

@vfdev-5
Collaborator

vfdev-5 commented Apr 1, 2019

@sudarshan85 we could consider providing a generic callback class in the contrib module.
But I can hardly imagine a class that uses all of these on_* methods. The example you cited, LRFinder, implements just three of them: on_train_begin, on_batch_end and on_train_end. This is very similar to how our classes with an attach method (Metric, ProgressBar, etc.) handle two or three events of the Engine.

@TilakSanghvi

@lmarti I am interested in this issue and would like to contribute to it. Please assign it to me.

@vfdev-5
Collaborator

vfdev-5 commented Oct 15, 2024

@TilakSanghvi you can start from this PR, #1693, and add tests.

5 participants