Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

set max_shard_size to None for pipeline save_pretrained #9447

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def save_pretrained(
save_directory: Union[str, os.PathLike],
safe_serialization: bool = True,
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
**kwargs,
):
Expand All @@ -205,7 +205,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `"10GB"`):
max_shard_size (`int` or `str`, defaults to `None`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
Expand Down Expand Up @@ -293,7 +293,8 @@ def is_saveable_module(name, value):
save_kwargs["safe_serialization"] = safe_serialization
if save_method_accept_variant:
save_kwargs["variant"] = variant
if save_method_accept_max_shard_size:
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size

save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
Expand Down
Loading