@@ -189,6 +189,7 @@ def save_pretrained(
189189 save_directory : Union [str , os .PathLike ],
190190 safe_serialization : bool = True ,
191191 variant : Optional [str ] = None ,
192+ max_shard_size : Union [int , str ] = "10GB" ,
192193 push_to_hub : bool = False ,
193194 ** kwargs ,
194195 ):
@@ -204,6 +205,13 @@ class implements both a save and loading method. The pipeline is easily reloaded
204205 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
205206 variant (`str`, *optional*):
206207 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
208+ max_shard_size (`int` or `str`, defaults to `"10GB"`):
209+ The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
210+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
211+ If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
212+ period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
213+ This is to establish a common default size for this argument across different libraries in the Hugging
214+ Face ecosystem (`transformers`, and `accelerate`, for example).
207215 push_to_hub (`bool`, *optional*, defaults to `False`):
208216 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
209217 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -278,12 +286,15 @@ def is_saveable_module(name, value):
278286 save_method_signature = inspect .signature (save_method )
279287 save_method_accept_safe = "safe_serialization" in save_method_signature .parameters
280288 save_method_accept_variant = "variant" in save_method_signature .parameters
289+ save_method_accept_max_shard_size = "max_shard_size" in save_method_signature .parameters
281290
282291 save_kwargs = {}
283292 if save_method_accept_safe :
284293 save_kwargs ["safe_serialization" ] = safe_serialization
285294 if save_method_accept_variant :
286295 save_kwargs ["variant" ] = variant
296+ if save_method_accept_max_shard_size :
297+ save_kwargs ["max_shard_size" ] = max_shard_size
287298
288299 save_method (os .path .join (save_directory , pipeline_component_name ), ** save_kwargs )
289300
0 commit comments