
Commit 498b191

apply review suggestions
1 parent bc9fc27 commit 498b191

File tree

12 files changed: +80 additions, -109 deletions


src/diffusers/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -261,7 +261,6 @@
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
             "attention_backend",
-            "enable_parallelism",
         ]
     )
     _import_structure["modular_pipelines"].extend(
@@ -942,7 +941,6 @@
            WanTransformer3DModel,
            WanVACETransformer3DModel,
            attention_backend,
-           enable_parallelism,
        )
        from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
        from .optimization import (

src/diffusers/models/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@
 _import_structure = {}

 if is_torch_available():
-    _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig", "enable_parallelism"]
+    _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -120,7 +120,7 @@

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from ._modeling_parallel import ContextParallelConfig, ParallelConfig, enable_parallelism
+        from ._modeling_parallel import ContextParallelConfig, ParallelConfig
         from .adapter import MultiAdapter, T2IAdapter
         from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel

src/diffusers/models/_modeling_parallel.py

Lines changed: 0 additions & 59 deletions

@@ -241,62 +241,3 @@ def __repr__(self):
 #
 # ContextParallelOutput:
 #     specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
-
-_ENABLE_PARALLELISM_WARN_ONCE = False
-
-
-@contextlib.contextmanager
-def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
-    """
-    A context manager to set the parallelism context for models or pipelines that have been parallelized.
-
-    Args:
-        model_or_pipeline (`DiffusionPipeline` or `ModelMixin`):
-            The model or pipeline to set the parallelism context for. The model or pipeline must have been parallelized
-            with `.enable_parallelism(ParallelConfig(...), ...)` before using this context manager.
-    """
-
-    from diffusers import DiffusionPipeline, ModelMixin
-
-    from .attention_dispatch import _AttentionBackendRegistry
-
-    global _ENABLE_PARALLELISM_WARN_ONCE
-    if not _ENABLE_PARALLELISM_WARN_ONCE:
-        logger.warning(
-            "Support for `enable_parallelism` is experimental and the API may be subject to change in the future."
-        )
-        _ENABLE_PARALLELISM_WARN_ONCE = True
-
-    if isinstance(model_or_pipeline, DiffusionPipeline):
-        parallelized_components = [
-            (name, component)
-            for name, component in model_or_pipeline.components.items()
-            if getattr(component, "_parallel_config", None) is not None
-        ]
-        if len(parallelized_components) > 1:
-            raise ValueError(
-                "Enabling parallelism on a pipeline is not possible when multiple internal components are parallelized. Please run "
-                "different stages of the pipeline separately with `enable_parallelism` on each component manually."
-            )
-        if len(parallelized_components) == 0:
-            raise ValueError(
-                "No parallelized components found in the pipeline. Please ensure at least one component is parallelized."
-            )
-        _, model_or_pipeline = parallelized_components[0]
-    elif isinstance(model_or_pipeline, ModelMixin):
-        if getattr(model_or_pipeline, "_parallel_config", None) is None:
-            raise ValueError(
-                "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager."
-            )
-    else:
-        raise TypeError(
-            f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got {type(model_or_pipeline)}. Please provide a valid model or pipeline."
-        )
-
-    # TODO: needs to be updated when more parallelism strategies are supported
-    old_parallel_config = _AttentionBackendRegistry._parallel_config
-    _AttentionBackendRegistry._parallel_config = model_or_pipeline._parallel_config.context_parallel_config
-
-    yield
-
-    _AttentionBackendRegistry._parallel_config = old_parallel_config

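The block removed above is the entire `enable_parallelism` context manager, whose top-level exports are also dropped in the `src/diffusers/__init__.py` and `src/diffusers/models/__init__.py` diffs. For reference, here is a minimal sketch of how the removed API was used before this commit, reconstructed from its docstring and body; the checkpoint name, the `ContextParallelConfig` argument, and the placeholder forward pass are illustrative assumptions, not taken from this commit.

# Hypothetical usage sketch against the parent commit (bc9fc27), where
# `enable_parallelism` was still exported from the top-level `diffusers` namespace.
import torch

from diffusers import AutoModel, enable_parallelism  # export removed by this commit
from diffusers.models import ContextParallelConfig, ParallelConfig

# Illustrative checkpoint; any model exposing `.enable_parallelism(ParallelConfig(...), ...)`
# (the prerequisite named in the removed docstring) would do.
transformer = AutoModel.from_pretrained(
    "some/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Parallelize the model first so that `transformer._parallel_config` is set;
# the ContextParallelConfig arguments (e.g. ring_degree) are an assumption.
transformer.enable_parallelism(
    ParallelConfig(context_parallel_config=ContextParallelConfig(ring_degree=2))
)

# The removed context manager temporarily installed the model's
# `context_parallel_config` on `_AttentionBackendRegistry._parallel_config`
# and restored the previous value when the block exited.
with enable_parallelism(transformer):
    pass  # run the parallelized forward passes / denoising loop here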