
Commit a7a5992

soulitzer authored and facebook-github-bot committed on May 25, 2021
Add no-grad inference mode note (pytorch#58513)
Summary: Adds a note explaining the difference between several often conflated mechanisms in the autograd note. Also adds a link to this note from the docs in `grad_mode` and `nn.module`.

Pull Request resolved: pytorch#58513
Reviewed By: gchanan
Differential Revision: D28651129
Pulled By: soulitzer
fbshipit-source-id: af9eb1749b641fc1b632815634eea36bf7979156
1 parent 5268b5a commit a7a5992

File tree

6 files changed: +179 -53 lines

 

‎docs/cpp/source/notes/inference_mode.rst

(-2)

@@ -30,8 +30,6 @@ Inside an ``InferenceMode`` block, we make the following performance guarantees:
 - Inplace operations on inference tensors are guaranteed not to do a version bump.

 For more implementation details of ``InferenceMode`` please see the `RFC-0011-InferenceMode <https://github.com/pytorch/rfcs/pull/17>`_.
-Currently this guard is only available in C++ frontend, adding python frontend support
-is tracked in #56608.

 Migration guide from ``AutoNonVariableTypeMode``
 ------------------------------------------------

‎docs/source/autograd.rst

(+4)

@@ -50,6 +50,10 @@ you can use it as ``functional.jacobian(lambda x: f(x, constant, flag=flag), inp
 Locally disabling gradient computation
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

+See :ref:`locally-disable-grad-doc` for more information on the differences
+between no-grad and inference mode as well as other related mechanisms that
+may be confused with the two.
+
 .. autosummary::
     :toctree: generated
     :nosignatures:

‎docs/source/notes/autograd.rst

(+151 -50)

@@ -8,56 +8,6 @@ operations. It's not strictly necessary to understand all this, but we recommend
 getting familiar with it, as it will help you write more efficient, cleaner
 programs, and can aid you in debugging.

-.. _excluding-subgraphs:
-
-Excluding subgraphs from backward
----------------------------------
-
-Every Tensor has a flag: :attr:`requires_grad` that allows for fine grained
-exclusion of subgraphs from gradient computation and can increase efficiency.
-
-.. _excluding-requires_grad:
-
-``requires_grad``
-^^^^^^^^^^^^^^^^^
-
-If there's a single input to an operation that requires gradient, its output
-will also require gradient. Conversely, only if all inputs don't require
-gradient, the output also won't require it. Backward computation is never
-performed in the subgraphs, where all Tensors didn't require gradients.
-
-.. code::
-
-    >>> x = torch.randn(5, 5) # requires_grad=False by default
-    >>> y = torch.randn(5, 5) # requires_grad=False by default
-    >>> z = torch.randn((5, 5), requires_grad=True)
-    >>> a = x + y
-    >>> a.requires_grad
-    False
-    >>> b = a + z
-    >>> b.requires_grad
-    True
-
-This is especially useful when you want to freeze part of your model, or you
-know in advance that you're not going to use gradients w.r.t. some parameters.
-For example if you want to finetune a pretrained CNN, it's enough to switch the
-:attr:`requires_grad` flags in the frozen base, and no intermediate buffers will
-be saved, until the computation gets to the last layer, where the affine
-transform will use weights that require gradient, and the output of the network
-will also require them.
-
-.. code::
-
-    model = torchvision.models.resnet18(pretrained=True)
-    for param in model.parameters():
-        param.requires_grad = False
-    # Replace the last fully-connected layer
-    # Parameters of newly constructed modules have requires_grad=True by default
-    model.fc = nn.Linear(512, 100)
-
-    # Optimize only the classifier
-    optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
-
 .. _how-autograd-encodes-history:

 How autograd encodes the history
@@ -86,6 +36,157 @@ flow statements, that can change the overall shape and size of the graph at
 every iteration. You don't have to encode all possible paths before you
 launch the training - what you run is what you differentiate.

+.. _locally-disable-grad-doc:
+
+Locally disabling gradient computation
+--------------------------------------
+
+There are several mechanisms available from Python to locally disable gradient
+computation:
+
+To disable gradients across entire blocks of code, there are context managers
+like no-grad mode and inference mode.
+For more fine-grained exclusion of subgraphs from gradient computation,
+you can set the ``requires_grad`` field of a tensor.
+
+Below, in addition to discussing the mechanisms above, we describe
+evaluation mode (:meth:`nn.Module.eval()`), a method that is not actually used
+to disable gradient computation but, because of its name, is often mixed up with the three.
+
+Setting ``requires_grad``
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:attr:`requires_grad` is a flag that allows for fine-grained exclusion of
+subgraphs from gradient computation. It takes effect in both the forward
+and backward passes:
+
+During the forward pass, an operation is only recorded in the backward graph if
+at least one of its input tensors requires grad.
+During the backward pass (``.backward()``), only leaf tensors with
+``requires_grad=True`` will have gradients accumulated into their ``.grad``
+fields.
+
+It is important to note that even though every tensor has this flag,
+*setting* it only makes sense for leaf tensors (tensors that do not have a
+``grad_fn``, e.g., a ``nn.Module``'s parameters).
+Non-leaf tensors (tensors that do have a ``grad_fn``) are tensors that have a
+backward graph associated with them. Thus their gradients will be needed
+as an intermediary result to compute the gradient for a leaf tensor that
+requires grad. From this definition, it is clear that all non-leaf tensors
+will automatically have ``requires_grad=True``.
+
+Setting ``requires_grad`` should be the main way you control which parts
+of the model are part of the gradient computation, for example, if you need to
+freeze parts of your pretrained model during model fine-tuning.
+
+To freeze parts of your model, simply apply ``.requires_grad_(False)`` to
+the parameters that you don't want updated. As described above,
+since computations that use these parameters as inputs are not recorded in
+the forward pass, they won't have their ``.grad`` fields updated in the backward
+pass because they won't be part of the backward graph in the first place, as
+desired.
+
+Because this is such a common pattern, ``requires_grad`` can also be set at
+the module level with :meth:`nn.Module.requires_grad_()`.
+When applied to a module, ``.requires_grad_()`` takes effect on all
+of the module's parameters (which have ``requires_grad=True`` by default).
+
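As an illustrative aside (an editor's sketch, not part of the diff above), the freezing pattern this new text describes might look like the following; the model, layer sizes, and learning rate are placeholders:

.. code::

    import torch
    from torch import nn, optim

    # A toy model: treat the first layer as a frozen "backbone",
    # the last layer as the trainable head.
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))

    # Freeze everything at the module level, then unfreeze only the head.
    model.requires_grad_(False)
    model[2].requires_grad_(True)

    # Only the unfrozen parameters need to be handed to the optimizer.
    optimizer = optim.SGD(
        [p for p in model.parameters() if p.requires_grad], lr=1e-2)

    loss = model(torch.randn(4, 16)).sum()
    loss.backward()

    assert model[0].weight.grad is None        # frozen: no grad accumulated
    assert model[2].weight.grad is not None    # trainable head receives a grad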
+Grad Modes
+^^^^^^^^^^
+
+Apart from setting ``requires_grad`` there are also three grad modes that can
+be selected from Python and that affect how computations in PyTorch are
+processed by autograd internally: default mode (grad mode), no-grad mode,
+and inference mode, all of which can be toggled via context managers and
+decorators.
+
+Default Mode (Grad Mode)
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The "default mode" is the mode we are implicitly in when no other modes like
+no-grad and inference mode are enabled. To contrast it with
+"no-grad mode", the default mode is also sometimes called "grad mode".
+
+The most important thing to know about the default mode is that it is the only
+mode in which ``requires_grad`` takes effect. ``requires_grad`` is always overridden
+to be ``False`` in the other two modes.
+
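A brief illustrative sketch of this behavior (editor's example, not part of the diff):

.. code::

    import torch

    x = torch.ones(2, 2, requires_grad=True)

    y = x * 2                  # default (grad) mode: requires_grad is respected
    print(y.requires_grad)     # True

    with torch.no_grad():
        z = x * 2              # requires_grad on x is overridden here
    print(z.requires_grad)     # False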
+No-grad Mode
+^^^^^^^^^^^^
+
+Computations in no-grad mode behave as if none of the inputs require grad.
+In other words, computations in no-grad mode are never recorded in the backward graph
+even if there are inputs that have ``requires_grad=True``.
+
+Enable no-grad mode when you need to perform operations that should not be
+recorded by autograd, but you'd still like to use the outputs of these
+computations in grad mode later. This context manager makes it convenient to
+disable gradients for a block of code or function without
+having to temporarily set tensors to have ``requires_grad=False``, and then
+back to ``True``.
+
+For example, no-grad mode might be useful when writing an optimizer: when
+performing the training update you'd like to update parameters
+in-place without the update being recorded by autograd.
+You also intend to use the updated parameters for computations in
+grad mode in the next forward pass.
+
+The implementations in :ref:`nn-init-doc` also
+rely on no-grad mode when initializing the parameters so as to avoid
+autograd tracking when updating the initialized parameters in-place.
+
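A minimal sketch of the optimizer-style use case described above (editor's example; the learning rate and tensor shape are arbitrary):

.. code::

    import torch

    w = torch.randn(3, requires_grad=True)
    loss = (w ** 2).sum()
    loss.backward()

    # The in-place update must not be recorded by autograd, but ``w``
    # keeps requires_grad=True for the next forward pass.
    with torch.no_grad():
        w -= 0.1 * w.grad
        w.grad.zero_()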
+Inference Mode
+^^^^^^^^^^^^^^
+
+Inference mode is the extreme version of no-grad mode. Just like in no-grad
+mode, computations in inference mode are not recorded in the backward graph, but
+enabling inference mode will allow PyTorch to speed up your model even more.
+This better runtime comes with a drawback: tensors created in inference mode
+will not be able to be used in computations to be recorded by autograd after
+exiting inference mode.
+
+Enable inference mode when you are performing computations that don't need
+to be recorded in the backward graph, AND you don't plan on using the tensors
+created in inference mode in any computation that is to be recorded by autograd later.
+
+It is recommended that you try out inference mode in the parts of your code
+that do not require autograd tracking (e.g., data processing and model evaluation).
+If it works out of the box
+for your use case it's a free performance win. If you run into errors after
+enabling inference mode, check that you are not using tensors created in
+inference mode in computations that are recorded by autograd after exiting inference
+mode. If you cannot avoid such use in your case, you can always switch back
+to no-grad mode.
+
+For details on inference mode please see
+`Inference Mode <https://pytorch.org/cppdocs/notes/inference_mode.html>`_.
+
+For implementation details of inference mode see
+`RFC-0011-InferenceMode <https://github.com/pytorch/rfcs/pull/17>`_.
+
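A brief illustrative sketch (editor's example; the model is a placeholder):

.. code::

    import torch

    model = torch.nn.Linear(8, 2)
    data = torch.randn(4, 8)

    with torch.inference_mode():
        out = model(data)        # not recorded in the backward graph

    print(out.requires_grad)     # False
    # ``out`` is an inference tensor: using it later in a computation that
    # autograd records will raise an error; fall back to no-grad mode if
    # such use cannot be avoided.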
+Evaluation Mode (``nn.Module.eval()``)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Evaluation mode is not actually a mechanism to locally disable gradient computation.
+It is included here anyway because it is sometimes mistaken for such a mechanism.
+
+Functionally, ``module.eval()`` (or equivalently ``module.train()``) is completely
+orthogonal to no-grad mode and inference mode. How ``model.eval()`` affects
+your model depends entirely on the specific modules used in your model and
+whether they define any training-mode specific behavior.
+
+You are responsible for calling ``model.eval()`` and ``model.train()`` if your
+model relies on modules such as :class:`torch.nn.Dropout` and
+:class:`torch.nn.BatchNorm2d` that may behave
+differently depending on training mode, for example, to avoid updating your
+BatchNorm running statistics on validation data.
+
+It is recommended that you always use ``model.train()`` when
+training and ``model.eval()`` when evaluating your model (validation/testing) even
+if you aren't sure your model has training-mode specific behavior, because a
+module you are using might be updated to behave differently in training and
+eval modes.
+
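A short sketch of the recommended pattern (editor's example; the module stack is illustrative). Note that ``eval()`` by itself does not stop autograd recording, so it is typically combined with no-grad or inference mode during validation:

.. code::

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Dropout(p=0.5))

    model.train()                    # dropout active, BatchNorm updates running stats
    model(torch.randn(4, 8))

    model.eval()                     # dropout disabled, running stats frozen
    with torch.no_grad():            # eval() alone does not disable gradient tracking
        model(torch.randn(4, 8))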
 In-place operations with autograd
 ---------------------------------

‎torch/autograd/grad_mode.py

(+17)

@@ -97,6 +97,10 @@ class no_grad(_DecoratorContextManager):
 
     Also functions as a decorator. (Make sure to instantiate with parenthesis.)
 
+    .. note::
+        No-grad is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
 
     Example::
 
@@ -136,6 +140,10 @@ class enable_grad(_DecoratorContextManager):
 
     Also functions as a decorator. (Make sure to instantiate with parenthesis.)
 
+    .. note::
+        enable_grad is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
 
     Example::
 
@@ -178,6 +186,10 @@ class set_grad_enabled(object):
         (``False``). This can be used to conditionally enable
         gradients.
 
+    .. note::
+        set_grad_enabled is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
 
     Example::
 
@@ -222,6 +234,11 @@ class inference_mode(_DecoratorContextManager):
 
     Also functions as a decorator. (Make sure to instantiate with parenthesis.)
 
+    .. note::
+        Inference mode is one of several mechanisms that can enable or
+        disable gradients locally; see :ref:`locally-disable-grad-doc` for
+        more information on how they compare.
+
     Args:
         mode (bool): Flag whether to enable or disable inference mode
 
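The decorator and context-manager usage these docstrings refer to can be sketched as follows (editor's example; the model and shapes are placeholders):

.. code::

    import torch

    @torch.no_grad()                 # instantiate with parentheses when decorating
    def evaluate(model, data):
        return model(data)

    model = torch.nn.Linear(4, 1)
    print(evaluate(model, torch.randn(2, 4)).requires_grad)    # False

    is_train = False
    with torch.set_grad_enabled(is_train):    # conditionally enable gradients
        out = model(torch.randn(2, 4))
    print(out.requires_grad)                                   # False

    with torch.inference_mode():              # usable as a context manager too
        out = model(torch.randn(2, 4))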

‎torch/nn/modules/module.py

(+6)

@@ -1651,6 +1651,9 @@ def eval(self: T) -> T:
 
         This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
 
+        See :ref:`locally-disable-grad-doc` for a comparison between
+        `.eval()` and several similar mechanisms that may be confused with it.
+
         Returns:
             Module: self
         """
@@ -1666,6 +1669,9 @@ def requires_grad_(self: T, requires_grad: bool = True) -> T:
         This method is helpful for freezing part of the module for finetuning
         or training parts of a model individually (e.g., GAN training).
 
+        See :ref:`locally-disable-grad-doc` for a comparison between
+        `.requires_grad_()` and several similar mechanisms that may be confused with it.
+
         Args:
             requires_grad (bool): whether autograd should record operations on
                 parameters in this module. Default: ``True``.
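As a hedged illustration of the use cases this docstring mentions (editor's sketch; the two modules below are stand-ins, not a real GAN):

.. code::

    import torch
    from torch import nn

    generator = nn.Linear(16, 16)       # illustrative stand-ins for the two
    discriminator = nn.Linear(16, 1)    # halves of a GAN

    # Discriminator step: only the discriminator should accumulate gradients.
    generator.requires_grad_(False)
    discriminator.requires_grad_(True)
    d_loss = discriminator(generator(torch.randn(4, 16))).mean()
    d_loss.backward()
    assert generator.weight.grad is None

    # Generator step: flip the flags; gradients then flow through the frozen
    # discriminator into the generator's parameters.
    generator.requires_grad_(True)
    discriminator.requires_grad_(False)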

‎torch/nn/parameter.py

(+1 -1)

@@ -18,7 +18,7 @@ class Parameter(torch.Tensor):
     Args:
         data (Tensor): parameter tensor.
         requires_grad (bool, optional): if the parameter requires gradient. See
-            :ref:`excluding-subgraphs` for more details. Default: `True`
+            :ref:`locally-disable-grad-doc` for more details. Default: `True`
     """
     def __new__(cls, data=None, requires_grad=True):
         if data is None:
