 
 
 @dataclass
-class ParallelConfig:
+class ContextParallelConfig:
+    # Number of GPUs to use for ring attention within a context parallel region
     ring_degree: Optional[int] = None
+    # Number of GPUs to use for ulysses attention within a context parallel region
     ulysses_degree: Optional[int] = None
-
-    def __post_init__(self):
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-
-
-@dataclass
-class _InternalParallelConfig:
-    rank: int
-    world_size: int
-    ring_degree: int
-    ulysses_degree: int
-    device: torch.device
-    cp_mesh: torch.distributed.device_mesh.DeviceMesh
-
     # Whether to convert output and LSE to float32 for ring attention numerical stability
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
 
+    _rank: int = None
+    _world_size: int = None
+    _device: torch.device = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None
     _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
     _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
     _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
     _ring_local_rank: int = None
     _ulysses_local_rank: int = None
 
     def __post_init__(self):
+        if self.ring_degree is None:
+            self.ring_degree = 1
+        if self.ulysses_degree is None:
+            self.ulysses_degree = 1
+
+    def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
+        self._rank = rank
+        self._world_size = world_size
+        self._device = device
+        self._mesh = mesh
+        if self.ring_degree is None:
+            self.ring_degree = 1
+        if self.ulysses_degree is None:
+            self.ulysses_degree = 1
         if self.rotate_method != "allgather":
-            raise ValueError(f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}.")
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )
         if self._flattened_mesh is None:
-            self._flattened_mesh = self.cp_mesh._flatten()
+            self._flattened_mesh = self._mesh._flatten()
         if self._ring_mesh is None:
-            self._ring_mesh = self.cp_mesh["ring"]
+            self._ring_mesh = self._mesh["ring"]
         if self._ulysses_mesh is None:
-            self._ulysses_mesh = self.cp_mesh["ulysses"]
+            self._ulysses_mesh = self._mesh["ulysses"]
         if self._ring_local_rank is None:
             self._ring_local_rank = self._ring_mesh.get_local_rank()
         if self._ulysses_local_rank is None:
             self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
 
 
+@dataclass
+class ParallelConfig:
+    context_parallel_config: Optional[ContextParallelConfig] = None
+
+    _rank: int = None
+    _world_size: int = None
+    _device: torch.device = None
+    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+
+    def setup(
+        self,
+        rank: int,
+        world_size: int,
+        device: torch.device,
+        *,
+        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+    ):
+        self._rank = rank
+        self._world_size = world_size
+        self._device = device
+        self._cp_mesh = cp_mesh
+        if self.context_parallel_config is not None:
+            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+
+
 @dataclass(frozen=True)
 class ContextParallelInput:
     split_dim: int
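
For orientation, here is a minimal sketch (not part of this diff) of how the two config objects above might be wired together on each rank. It assumes `torch.distributed` is already initialized; the chosen degrees (2 ring × 4 ulysses on 8 GPUs) and the `init_device_mesh` call are illustrative, while the `"ring"`/`"ulysses"` mesh dimension names mirror the keys that `ContextParallelConfig.setup()` indexes into.

```python
# Illustrative sketch only, not part of this diff.
# Assumes torch.distributed is initialized with world_size == ring_degree * ulysses_degree.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

cp_config = ContextParallelConfig(ring_degree=2, ulysses_degree=4)
parallel_config = ParallelConfig(context_parallel_config=cp_config)

rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())

# Mesh dimension names must match the keys ContextParallelConfig.setup() looks up.
cp_mesh = init_device_mesh(
    "cuda",
    (cp_config.ring_degree, cp_config.ulysses_degree),
    mesh_dim_names=("ring", "ulysses"),
)
parallel_config.setup(rank, world_size, device, cp_mesh=cp_mesh)
```
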
@@ -145,7 +175,7 @@ def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin
         parallelized_components = [
             (name, component)
             for name, component in model_or_pipeline.components.items()
-            if getattr(component, "_internal_parallel_config", None) is not None
+            if getattr(component, "_parallel_config", None) is not None
         ]
         if len(parallelized_components) > 1:
             raise ValueError(
@@ -158,7 +188,7 @@ def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin
             )
         _, model_or_pipeline = parallelized_components[0]
     elif isinstance(model_or_pipeline, ModelMixin):
-        if getattr(model_or_pipeline, "_internal_parallel_config", None) is None:
+        if getattr(model_or_pipeline, "_parallel_config", None) is None:
             raise ValueError(
                 "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager."
             )
@@ -167,8 +197,9 @@ def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin
167197 f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got { type (model_or_pipeline )} . Please provide a valid model or pipeline."
168198 )
169199
200+ # TODO: needs to be updated when more parallelism strategies are supported
170201 old_parallel_config = _AttentionBackendRegistry ._parallel_config
171- _AttentionBackendRegistry ._parallel_config = model_or_pipeline ._internal_parallel_config
202+ _AttentionBackendRegistry ._parallel_config = model_or_pipeline ._parallel_config . context_parallel_config
172203
173204 yield
174205
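
As a rough usage sketch of the `enable_parallelism` context manager touched above (the `pipe` object, the prompt, and the exact way its transformer gets parallelized are placeholders, not confirmed by this diff):

```python
# Hypothetical usage, for illustration only.
# `pipe` is assumed to be a DiffusionPipeline whose transformer component was
# parallelized beforehand (a `.parallelize()`-style call, per the error messages above),
# so exactly one component carries a `_parallel_config`.
with enable_parallelism(pipe):
    image = pipe(prompt="an astronaut riding a horse").images[0]
```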