diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py
index 3d18daf0f..cfe90b57a 100644
--- a/thinc/shims/pytorch.py
+++ b/thinc/shims/pytorch.py
@@ -111,7 +111,11 @@ def predict(self, inputs: ArgsKwargs) -> Any:
         """
         self._model.eval()
         with torch.no_grad():
-            with torch.amp.autocast("cuda", self._mixed_precision):
+            # NB: torch.cuda.amp.autocast took the enabled flag as its first
+            # positional argument, but torch.autocast takes dtype as its
+            # second positional parameter, so passing self._mixed_precision
+            # positionally raises an error. Pass it via `enabled` instead.
+            with torch.autocast(device_type="cuda", enabled=self._mixed_precision):
                 outputs = self._model(*inputs.args, **inputs.kwargs)
         self._model.train()
         return outputs
@@ -125,7 +129,11 @@ def begin_update(self, inputs: ArgsKwargs):
         self._model.train()
 
         # Note: mixed-precision autocast must not be applied to backprop.
-        with torch.amp.autocast("cuda", self._mixed_precision):
+        # NB: torch.cuda.amp.autocast took the enabled flag as its first
+        # positional argument, but torch.autocast takes dtype as its
+        # second positional parameter, so passing self._mixed_precision
+        # positionally raises an error. Pass it via `enabled` instead.
+        with torch.autocast(device_type="cuda", enabled=self._mixed_precision):
             output = self._model(*inputs.args, **inputs.kwargs)
 
         def backprop(grads):
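
Note for reviewers: a minimal, self-contained sketch of the migration this
diff performs, runnable on its own (assumes a CUDA device and a toy
nn.Linear model; the `mixed_precision` flag here stands in for
self._mixed_precision and is illustrative, not part of this patch):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4).cuda()
    x = torch.randn(2, 4, device="cuda")
    mixed_precision = True

    # Deprecated spelling: torch.cuda.amp.autocast takes the enabled flag
    # as its first positional argument.
    with torch.cuda.amp.autocast(mixed_precision):
        y_old = model(x)

    # Replacement: torch.autocast takes device_type first and dtype second,
    # so the flag must be passed as the `enabled` keyword.
    with torch.autocast(device_type="cuda", enabled=mixed_precision):
        y_new = model(x)

    # With autocast enabled, CUDA matmul outputs default to float16.
    print(y_old.dtype, y_new.dtype)

Passing `enabled=` rather than dropping the flag preserves the existing
behaviour: when mixed precision is turned off, the forward pass runs in
full precision instead of always autocasting.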