@@ -189,7 +189,7 @@ def save_pretrained(
189
189
save_directory : Union [str , os .PathLike ],
190
190
safe_serialization : bool = True ,
191
191
variant : Optional [str ] = None ,
192
- max_shard_size : Union [int , str ] = "10GB" ,
192
+ max_shard_size : Optional [ Union [int , str ]] = None ,
193
193
push_to_hub : bool = False ,
194
194
** kwargs ,
195
195
):
@@ -205,7 +205,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
205
205
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
206
206
variant (`str`, *optional*):
207
207
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
208
- max_shard_size (`int` or `str`, defaults to `"10GB" `):
208
+ max_shard_size (`int` or `str`, defaults to `None `):
209
209
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
210
210
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
211
211
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
@@ -293,7 +293,8 @@ def is_saveable_module(name, value):
293
293
save_kwargs ["safe_serialization" ] = safe_serialization
294
294
if save_method_accept_variant :
295
295
save_kwargs ["variant" ] = variant
296
- if save_method_accept_max_shard_size :
296
+ if save_method_accept_max_shard_size and max_shard_size is not None :
297
+ # max_shard_size is expected to not be None in ModelMixin
297
298
save_kwargs ["max_shard_size" ] = max_shard_size
298
299
299
300
save_method (os .path .join (save_directory , pipeline_component_name ), ** save_kwargs )
0 commit comments