Skip to content

Commit 2d29f94

Browse files
committed
Refactor: Rename Magi1AttnProcessor2_0 and Magi1TransformerBlock classes for clarity
Refactor: Rename VAE components for clarity. Renames the attention processor and transformer block classes to be more specific to the VAE architecture. This improves code readability by making the purpose of these components more explicit.
1 parent 7415473 commit 2d29f94

File tree

1 file changed: +6 additions, −6 deletions

src/diffusers/models/autoencoders/autoencoder_kl_magi1.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3535

3636

37-
class Magi1AttnProcessor2_0:
37+
class Magi1VAEAttnProcessor2_0:
3838
def __init__(self, dim, num_heads=8):
3939
if not hasattr(F, "scaled_dot_product_attention"):
4040
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
@@ -80,7 +80,7 @@ def __call__(
8080
return hidden_states + identity
8181

8282

83-
class Magi1TransformerBlock(nn.Module):
83+
class Magi1VAETransformerBlock(nn.Module):
8484
def __init__(
8585
self,
8686
dim: int,
@@ -99,7 +99,7 @@ def __init__(
9999
bias=True,
100100
cross_attention_dim=None,
101101
out_bias=True,
102-
processor=Magi1AttnProcessor2_0(dim, num_heads),
102+
processor=Magi1VAEAttnProcessor2_0(dim, num_heads),
103103
)
104104

105105
self.drop_path = nn.Identity()
@@ -154,7 +154,7 @@ def __init__(
154154
# 3. Transformer blocks
155155
self.blocks = nn.ModuleList(
156156
[
157-
Magi1TransformerBlock(
157+
Magi1VAETransformerBlock(
158158
inner_dim,
159159
num_attention_heads,
160160
ffn_dim,
@@ -241,7 +241,7 @@ def __init__(
241241
# 3. Transformer blocks
242242
self.blocks = nn.ModuleList(
243243
[
244-
Magi1TransformerBlock(
244+
Magi1VAETransformerBlock(
245245
inner_dim,
246246
num_attention_heads,
247247
ffn_dim,
@@ -317,7 +317,7 @@ class AutoencoderKLMagi1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
317317

318318
_supports_gradient_checkpointing = False
319319
_skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
320-
_no_split_modules = ["Magi1TransformerBlock"]
320+
_no_split_modules = ["Magi1VAETransformerBlock"]
321321
_keep_in_fp32_modules = ["qkv_norm", "norm1", "norm2"]
322322
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
323323

0 commit comments

Comments
 (0)