Custom backward that requires network input #16043
Replies: 2 comments 1 reply
-
I also want to try this balancer class with Lightning! Did you figure this out?
-
@joecomerisnotavailable @lukasschmit I found this issue while also exploring adapting the loss balancer mechanism from Encodec. Initially I thought that you needed to pass the network input as well, but as it turns out, the paper itself says you should pass the *output* of the network. You can then use the mechanism by overriding the default training step as usual; you must set `self.automatic_optimization = False` in your LightningModule so Lightning hands control of the backward pass to you.
This seems to work for me at first glance, but I'll report back if I encounter any other difficulties.
-
I am interested in using a loss-balancing class defined here:
https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py#L31
which has an expected usage that replaces the standard `loss.backward()` call, or in replicating its functionality in a way that cooperates with PyTorch Lightning and preferably still allows mixed-precision training. The class re-weights per-loss gradients so that the predefined weight of each loss corresponds to that loss's share of the norm of the total gradient step.
Since the class replaces the usual backward pass and requires the model's input as an argument, I'm not sure whether to override the LightningModule's `backward`, or `manual_backward`, or bypass both in the closure defined in `training_step`, or whether an alternative implementation of the balancing utility, callable from the `on_after_backward` hook, is needed.
My main concern with bypassing or overriding `manual_backward` is causing a major slowdown or silently breaking the mixed-precision handling.
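To make the mechanism concrete, here is a hypothetical re-implementation sketch of the balancing idea (the function name and details are mine, not the library's API): compute each loss's gradient with respect to the network output only, rescale each gradient to its predefined weight, sum them, and backprop the combined gradient through the network once.

```python
import torch


def balanced_backward(losses, weights, out):
    """Sketch of Encodec-style gradient balancing (hypothetical helper).

    `losses`  : dict of name -> scalar loss, each depending on `out`
    `weights` : dict of name -> predefined weight
    `out`     : the network output tensor (requires grad)
    """
    total_weight = sum(weights[name] for name in losses)
    grad = torch.zeros_like(out)
    for name, loss in losses.items():
        # Gradient of this loss w.r.t. the network output only.
        (g,) = torch.autograd.grad(loss, [out], retain_graph=True)
        # Normalize so this loss contributes weights[name]/total_weight
        # of the combined gradient's norm.
        grad += (weights[name] / total_weight) * g / (g.norm() + 1e-12)
    # A single backward pass through the network with the balanced gradient.
    out.backward(grad)


# Toy usage: one linear model, two competing losses.
model = torch.nn.Linear(4, 4)
x = torch.randn(8, 4)
out = model(x)
losses = {"l1": out.abs().mean(), "l2": (out ** 2).mean()}
balanced_backward(losses, {"l1": 1.0, "l2": 3.0}, out)
```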
Thanks in advance for any help with this.