@@ -241,62 +241,3 @@ def __repr__(self):
241241# 
242242# ContextParallelOutput: 
243243#     specifies how to gather the input tensor in the post-forward hook in the layer it is attached to 
244- 
245- _ENABLE_PARALLELISM_WARN_ONCE  =  False 
246- 
247- 
248- @contextlib .contextmanager  
249- def  enable_parallelism (model_or_pipeline : Union ["DiffusionPipeline" , "ModelMixin" ]):
250-     """ 
251-     A context manager to set the parallelism context for models or pipelines that have been parallelized. 
252- 
253-     Args: 
254-         model_or_pipeline (`DiffusionPipeline` or `ModelMixin`): 
255-             The model or pipeline to set the parallelism context for. The model or pipeline must have been parallelized 
256-             with `.enable_parallelism(ParallelConfig(...), ...)` before using this context manager. 
257-     """ 
258- 
259-     from  diffusers  import  DiffusionPipeline , ModelMixin 
260- 
261-     from  .attention_dispatch  import  _AttentionBackendRegistry 
262- 
263-     global  _ENABLE_PARALLELISM_WARN_ONCE 
264-     if  not  _ENABLE_PARALLELISM_WARN_ONCE :
265-         logger .warning (
266-             "Support for `enable_parallelism` is experimental and the API may be subject to change in the future." 
267-         )
268-         _ENABLE_PARALLELISM_WARN_ONCE  =  True 
269- 
270-     if  isinstance (model_or_pipeline , DiffusionPipeline ):
271-         parallelized_components  =  [
272-             (name , component )
273-             for  name , component  in  model_or_pipeline .components .items ()
274-             if  getattr (component , "_parallel_config" , None ) is  not   None 
275-         ]
276-         if  len (parallelized_components ) >  1 :
277-             raise  ValueError (
278-                 "Enabling parallelism on a pipeline is not possible when multiple internal components are parallelized. Please run " 
279-                 "different stages of the pipeline separately with `enable_parallelism` on each component manually." 
280-             )
281-         if  len (parallelized_components ) ==  0 :
282-             raise  ValueError (
283-                 "No parallelized components found in the pipeline. Please ensure at least one component is parallelized." 
284-             )
285-         _ , model_or_pipeline  =  parallelized_components [0 ]
286-     elif  isinstance (model_or_pipeline , ModelMixin ):
287-         if  getattr (model_or_pipeline , "_parallel_config" , None ) is  None :
288-             raise  ValueError (
289-                 "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager." 
290-             )
291-     else :
292-         raise  TypeError (
293-             f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got { type (model_or_pipeline )}  . Please provide a valid model or pipeline." 
294-         )
295- 
296-     # TODO: needs to be updated when more parallelism strategies are supported 
297-     old_parallel_config  =  _AttentionBackendRegistry ._parallel_config 
298-     _AttentionBackendRegistry ._parallel_config  =  model_or_pipeline ._parallel_config .context_parallel_config 
299- 
300-     yield 
301- 
302-     _AttentionBackendRegistry ._parallel_config  =  old_parallel_config 
0 commit comments