Skip to content

Commit

Permalink
Add weights_only parameter to default_deserialize_torch_model for enh…
Browse files Browse the repository at this point in the history
…anced security (#950)
  • Loading branch information
NripeshN authored Dec 11, 2024
1 parent 1f4b5f0 commit 256e403
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions thinc/shims/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def default_serialize_torch_model(model: Any) -> bytes:


def default_deserialize_torch_model(
model: Any, state_bytes: bytes, device: "torch.device"
model: Any, state_bytes: bytes, device: "torch.device", weights_only: bool = True
) -> Any:
"""Deserializes the parameters of the wrapped PyTorch model and
moves it to the specified device.
Expand All @@ -244,12 +244,15 @@ def default_deserialize_torch_model(
Serialized parameters as a byte stream.
device:
PyTorch device to which the model is bound.
weights_only:
Whether to only load the model's weights (default: True). Setting this
to True enhances security and avoids loading arbitrary objects.
Returns:
The deserialized model.
"""
filelike = BytesIO(state_bytes)
filelike.seek(0)
model.load_state_dict(torch.load(filelike, map_location=device))
state_dict = torch.load(filelike, map_location=device, weights_only=weights_only)
model.to(device)
return model

0 comments on commit 256e403

Please # to comment.