diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py index 3d18daf0f..0efe47a01 100644 --- a/thinc/shims/pytorch.py +++ b/thinc/shims/pytorch.py @@ -233,7 +233,7 @@ def default_serialize_torch_model(model: Any) -> bytes: def default_deserialize_torch_model( - model: Any, state_bytes: bytes, device: "torch.device" + model: Any, state_bytes: bytes, device: "torch.device", weights_only: bool = True ) -> Any: """Deserializes the parameters of the wrapped PyTorch model and moves it to the specified device. @@ -244,12 +244,15 @@ def default_deserialize_torch_model( Serialized parameters as a byte stream. device: PyTorch device to which the model is bound. + weights_only: + Whether to only load the model's weights (default: True). Setting this + to True enhances security and avoids loading arbitrary objects. Returns: The deserialized model. """ filelike = BytesIO(state_bytes) filelike.seek(0) - model.load_state_dict(torch.load(filelike, map_location=device)) + state_dict = torch.load(filelike, map_location=device, weights_only=weights_only) model.to(device) return model