Skip to content

Commit 2454b98

Browse files
authored
Allow max shard size to be specified when saving pipeline (huggingface#9440)
allow max shard size to be specified when saving pipeline
1 parent 37e3603 commit 2454b98

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def save_pretrained(
189189
save_directory: Union[str, os.PathLike],
190190
safe_serialization: bool = True,
191191
variant: Optional[str] = None,
192+
max_shard_size: Union[int, str] = "10GB",
192193
push_to_hub: bool = False,
193194
**kwargs,
194195
):
@@ -204,6 +205,13 @@ class implements both a save and loading method. The pipeline is easily reloaded
204205
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
205206
variant (`str`, *optional*):
206207
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
208+
max_shard_size (`int` or `str`, defaults to `"10GB"`):
209+
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
210+
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
211+
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
212+
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
213+
This is to establish a common default size for this argument across different libraries in the Hugging
214+
Face ecosystem (`transformers` and `accelerate`, for example).
207215
push_to_hub (`bool`, *optional*, defaults to `False`):
208216
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
209217
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -278,12 +286,15 @@ def is_saveable_module(name, value):
278286
save_method_signature = inspect.signature(save_method)
279287
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
280288
save_method_accept_variant = "variant" in save_method_signature.parameters
289+
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
281290

282291
save_kwargs = {}
283292
if save_method_accept_safe:
284293
save_kwargs["safe_serialization"] = safe_serialization
285294
if save_method_accept_variant:
286295
save_kwargs["variant"] = variant
296+
if save_method_accept_max_shard_size:
297+
save_kwargs["max_shard_size"] = max_shard_size
287298

288299
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
289300

0 commit comments

Comments
 (0)