Skip to content

Commit cb3a0de

Browse files
a-r-r-o-wyiyixuxu
authored andcommitted
set max_shard_size to None for pipeline save_pretrained (#9447)
* update default max_shard_size * add None check to fix tests --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 0494cbd commit cb3a0de

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/diffusers/pipelines/pipeline_utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def save_pretrained(
189189
save_directory: Union[str, os.PathLike],
190190
safe_serialization: bool = True,
191191
variant: Optional[str] = None,
192-
max_shard_size: Union[int, str] = "10GB",
192+
max_shard_size: Optional[Union[int, str]] = None,
193193
push_to_hub: bool = False,
194194
**kwargs,
195195
):
@@ -205,7 +205,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
205205
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
206206
variant (`str`, *optional*):
207207
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
208-
max_shard_size (`int` or `str`, defaults to `"10GB"`):
208+
max_shard_size (`int` or `str`, defaults to `None`):
209209
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
210210
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
211211
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
@@ -293,7 +293,8 @@ def is_saveable_module(name, value):
293293
save_kwargs["safe_serialization"] = safe_serialization
294294
if save_method_accept_variant:
295295
save_kwargs["variant"] = variant
296-
if save_method_accept_max_shard_size:
296+
if save_method_accept_max_shard_size and max_shard_size is not None:
297+
# max_shard_size is expected to not be None in ModelMixin
297298
save_kwargs["max_shard_size"] = max_shard_size
298299

299300
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

0 commit comments

Comments
 (0)