Commit a5d5ef4

make style
1 parent 498b191 commit a5d5ef4
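
(`make style` presumably invokes the repository's auto-formatters; every hunk below is a mechanical re-wrap of over-long lines, plus removal of imports flagged as unused.)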

4 files changed: +21 -10 lines

src/diffusers/models/_modeling_parallel.py

Lines changed: 1 addition & 3 deletions
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
@@ -25,8 +24,7 @@
 
 
 if TYPE_CHECKING:
-    from ..pipelines.pipeline_utils import DiffusionPipeline
-    from .modeling_utils import ModelMixin
+    pass
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
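
The `if TYPE_CHECKING:` block is kept but reduced to `pass` because its two type-only imports are no longer referenced. For context, a generic sketch of the idiom, not code from this commit (`helper` and the imported name are illustrative):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never imported at
    # runtime, which sidesteps circular-import problems for annotation-only names.
    from .modeling_utils import ModelMixin


def helper(model: "ModelMixin") -> None:  # string forward reference
    ...
```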

src/diffusers/models/attention_dispatch.py

Lines changed: 8 additions & 5 deletions
@@ -235,9 +235,13 @@ def list_backends(cls):
         return list(cls._backends.keys())
 
     @classmethod
-    def _is_context_parallel_enabled(cls, backend: AttentionBackendName, parallel_config: Optional["ContextParallelConfig"]) -> bool:
+    def _is_context_parallel_enabled(
+        cls, backend: AttentionBackendName, parallel_config: Optional["ContextParallelConfig"]
+    ) -> bool:
         supports_context_parallel = backend in cls._supports_context_parallel
-        is_degree_greater_than_1 = parallel_config is not None and (parallel_config.ring_degree > 1 or parallel_config.ulysses_degree > 1)
+        is_degree_greater_than_1 = parallel_config is not None and (
+            parallel_config.ring_degree > 1 or parallel_config.ulysses_degree > 1
+        )
         return supports_context_parallel and is_degree_greater_than_1
 
 
@@ -285,9 +289,8 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
-    if (
-        parallel_config is not None
-        and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name, parallel_config)
+    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
+        backend_name, parallel_config
     ):
         raise ValueError(
             f"Backend {backend_name} either does not support context parallelism or context parallelism "

src/diffusers/models/transformers/transformer_bria.py

Lines changed: 6 additions & 1 deletion
@@ -162,7 +162,12 @@ def __call__(
             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, backend=self._attention_backend, parallel_config=self._parallel_config
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 6 additions & 1 deletion
@@ -116,7 +116,12 @@ def __call__(
             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, backend=self._attention_backend, parallel_config=self._parallel_config
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
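
Both transformer hunks are the same mechanical re-wrap: a call over the formatter's line-length limit gets one argument per line plus a trailing comma (the "magic trailing comma", which keeps the call exploded so future argument additions are one-line diffs). A toy sketch of the before/after, with `fn` standing in for `dispatch_attention_fn`:

```python
def fn(q, k, v, *, attn_mask=None, backend=None, parallel_config=None):
    """Stub with the same keyword shape as dispatch_attention_fn."""
    return q


q, k, v = 1, 2, 3

# Before `make style`: one call line exceeding the configured line length.
out = fn(q, k, v, attn_mask=None, backend="native", parallel_config=None)

# After: the formatter explodes the call, one argument per line; the trailing
# comma pins it in this form even if it would later fit on a single line.
out = fn(
    q,
    k,
    v,
    attn_mask=None,
    backend="native",
    parallel_config=None,
)
```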
