Skip to content

Commit 4b2fcc1

Browse files
committed
support passing parallel_config to from_pretrained
1 parent b85c26c commit 4b2fcc1

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
963963
quantization_config = kwargs.pop("quantization_config", None)
964964
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
965965
disable_mmap = kwargs.pop("disable_mmap", False)
966+
parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
966967

967968
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
968969
if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1343,6 +1344,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
13431344
# Set model in evaluation mode to deactivate DropOut modules by default
13441345
model.eval()
13451346

1347+
if parallel_config is not None:
1348+
model.enable_parallelism(config=parallel_config)
1349+
13461350
if output_loading_info:
13471351
return model, loading_info
13481352

0 commit comments

Comments
 (0)