Commit

Try to fix mixed precision autocast
honnibal committed Oct 1, 2024
1 parent e52747f commit 0870d30
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions thinc/shims/pytorch.py
@@ -111,7 +111,11 @@ def predict(self, inputs: ArgsKwargs) -> Any:
         """
         self._model.eval()
         with torch.no_grad():
-            with torch.amp.autocast("cuda", self._mixed_precision):
+            # NB: Previously this was torch.cuda.amp.autocast, passing a boolean
+            # for mixed_precision. That doesn't seem to match the docs, and now
+            # it raises an error when moving from the deprecated function. So
+            # I've removed the argument but I'm not certain it's correct.
+            with torch.autocast(device_type="cuda"):
                 outputs = self._model(*inputs.args, **inputs.kwargs)
         self._model.train()
         return outputs
@@ -125,7 +129,11 @@ def begin_update(self, inputs: ArgsKwargs):
         self._model.train()
 
         # Note: mixed-precision autocast must not be applied to backprop.
-        with torch.amp.autocast("cuda", self._mixed_precision):
+        # NB: Previously this was torch.cuda.amp.autocast, passing a boolean
+        # for mixed_precision. That doesn't seem to match the docs, and now
+        # it raises an error when moving from the deprecated function. So
+        # I've removed the argument but I'm not certain it's correct.
+        with torch.autocast("cuda"):
             output = self._model(*inputs.args, **inputs.kwargs)
 
         def backprop(grads):
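For reference, a minimal sketch (not part of this commit) of how the boolean mixed_precision toggle could be preserved under the non-deprecated API: torch.autocast accepts an enabled keyword, so the old behaviour of passing a boolean to torch.cuda.amp.autocast maps onto enabled=... rather than dropping the argument entirely. The model, inputs, and mixed_precision names below are placeholders, not thinc internals.

    # Sketch only: keeping the mixed_precision flag with torch.autocast.
    # Assumes `inputs` behaves like thinc's ArgsKwargs (has .args and .kwargs).
    import torch

    def predict_with_optional_autocast(model, inputs, mixed_precision: bool):
        model.eval()
        with torch.no_grad():
            # enabled=False makes autocast a no-op, reproducing the old
            # boolean argument of torch.cuda.amp.autocast.
            with torch.autocast(device_type="cuda", enabled=mixed_precision):
                outputs = model(*inputs.args, **inputs.kwargs)
        model.train()
        return outputs

Whether the flag should be kept or dropped is a judgement call for the commit author; the sketch only shows that the non-deprecated API can express the previous behaviour.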
