Skip to content

Commit 58dc666

Browse files
committed
style
1 parent 6027704 commit 58dc666

File tree

6 files changed

+169
-109
lines changed

6 files changed

+169
-109
lines changed

scripts/convert_magi1_to_diffusers.py

Lines changed: 135 additions & 79 deletions
Large diffs are not rendered by default.

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@
185185
"Kandinsky3UNet",
186186
"LatteTransformer3DModel",
187187
"LTXVideoTransformer3DModel",
188-
"Magi1Transformer3DModel",
189188
"Lumina2Transformer2DModel",
190189
"LuminaNextDiT2DModel",
190+
"Magi1Transformer3DModel",
191191
"MochiTransformer3DModel",
192192
"ModelMixin",
193193
"MotionAdapter",
@@ -805,9 +805,9 @@
805805
Kandinsky3UNet,
806806
LatteTransformer3DModel,
807807
LTXVideoTransformer3DModel,
808-
Magi1Transformer3DModel,
809808
Lumina2Transformer2DModel,
810809
LuminaNextDiT2DModel,
810+
Magi1Transformer3DModel,
811811
MochiTransformer3DModel,
812812
ModelMixin,
813813
MotionAdapter,

src/diffusers/models/transformers/transformer_magi1.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
import math
1616
from typing import Any, Dict, Optional, Tuple, Union
1717

18-
from typing import Optional
19-
2018
import torch
2119
import torch.nn as nn
2220
import torch.nn.functional as F
@@ -40,9 +38,8 @@ class Magi1AttnProcessor2_0:
4038
r"""
4139
Processor for implementing MAGI-1 attention mechanism.
4240
43-
This processor handles both self-attention and cross-attention for the MAGI-1 model,
44-
following diffusers' standard attention processor interface. It supports image conditioning
45-
for image-to-video generation tasks.
41+
This processor handles both self-attention and cross-attention for the MAGI-1 model, following diffusers' standard
42+
attention processor interface. It supports image conditioning for image-to-video generation tasks.
4643
"""
4744

4845
def __init__(self):
@@ -62,7 +59,7 @@ def __call__(
6259
if attn.add_k_proj is not None and encoder_hidden_states is not None:
6360
# Extract image conditioning from the concatenated encoder states
6461
# The text encoder context length is typically 512 tokens
65-
text_context_length = getattr(attn, 'text_context_length', 512)
62+
text_context_length = getattr(attn, "text_context_length", 512)
6663
image_context_length = encoder_hidden_states.shape[1] - text_context_length
6764
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
6865
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
@@ -89,6 +86,7 @@ def __call__(
8986

9087
# Apply rotary embeddings if provided
9188
if rotary_emb is not None:
89+
9290
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
9391
dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
9492
x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
@@ -129,16 +127,17 @@ class Magi1ImageEmbedding(torch.nn.Module):
129127
"""
130128
Image embedding layer for the MAGI-1 model.
131129
132-
This module processes image conditioning features for image-to-video generation tasks.
133-
It applies layer normalization, a feed-forward transformation, and optional positional
134-
embeddings to prepare image features for cross-attention.
130+
This module processes image conditioning features for image-to-video generation tasks. It applies layer
131+
normalization, a feed-forward transformation, and optional positional embeddings to prepare image features for
132+
cross-attention.
135133
136134
Args:
137135
in_features (`int`): Input feature dimension.
138136
out_features (`int`): Output feature dimension.
139137
pos_embed_seq_len (`int`, optional): Sequence length for positional embeddings.
140138
If provided, learnable positional embeddings will be added to the input.
141139
"""
140+
142141
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
143142
super().__init__()
144143

@@ -179,6 +178,7 @@ class Magi1TimeTextImageEmbedding(nn.Module):
179178
image_embed_dim (`int`, optional): Input dimension of image embeddings.
180179
pos_embed_seq_len (`int`, optional): Sequence length for image positional embeddings.
181180
"""
181+
182182
def __init__(
183183
self,
184184
dim: int,
@@ -269,9 +269,9 @@ class Magi1TransformerBlock(nn.Module):
269269
"""
270270
A transformer block used in the MAGI-1 model.
271271
272-
This block follows diffusers' design philosophy with separate self-attention (attn1)
273-
and cross-attention (attn2) modules, while faithfully implementing the original
274-
MAGI-1 logic through appropriate parameter mapping during conversion.
272+
This block follows diffusers' design philosophy with separate self-attention (attn1) and cross-attention (attn2)
273+
modules, while faithfully implementing the original MAGI-1 logic through appropriate parameter mapping during
274+
conversion.
275275
276276
Args:
277277
dim (`int`): The number of channels in the input and output.
@@ -369,9 +369,9 @@ class Magi1Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
369369
r"""
370370
A Transformer model for video-like data used in the Magi1 model.
371371
372-
This model implements a 3D transformer architecture for video generation with support for text conditioning
373-
and optional image conditioning. The model uses rotary position embeddings and adaptive layer normalization
374-
for temporal and spatial modeling.
372+
This model implements a 3D transformer architecture for video generation with support for text conditioning and
373+
optional image conditioning. The model uses rotary position embeddings and adaptive layer normalization for
374+
temporal and spatial modeling.
375375
376376
Args:
377377
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
@@ -542,11 +542,7 @@ def forward(
542542
# Patchify: (B, C, T, H, W) -> (B, T//p_t, H//p_h, W//p_w, C*p_t*p_h*p_w)
543543
hidden_states = hidden_states.unfold(2, p_t, p_t).unfold(3, p_h, p_h).unfold(4, p_w, p_w)
544544
hidden_states = hidden_states.contiguous().view(
545-
batch_size,
546-
num_frames // p_t,
547-
height // p_h,
548-
width // p_w,
549-
num_channels * p_t * p_h * p_w
545+
batch_size, num_frames // p_t, height // p_h, width // p_w, num_channels * p_t * p_h * p_w
550546
)
551547
# Reshape to sequence: (B, T*H*W, C*p_t*p_h*p_w)
552548
hidden_states = hidden_states.flatten(1, 3)
@@ -595,15 +591,22 @@ def forward(
595591
# Rearrange patches: (B, T//p_t, H//p_h, W//p_w, C*p_t*p_h*p_w) -> (B, C, T, H, W)
596592
p_t, p_h, p_w = self.config.patch_size
597593
hidden_states = hidden_states.view(
598-
batch_size, post_patch_num_frames, post_patch_height, post_patch_width,
599-
self.config.out_channels, p_t, p_h, p_w
594+
batch_size,
595+
post_patch_num_frames,
596+
post_patch_height,
597+
post_patch_width,
598+
self.config.out_channels,
599+
p_t,
600+
p_h,
601+
p_w,
600602
)
601603
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
602604
output = hidden_states.contiguous().view(
603-
batch_size, self.config.out_channels,
605+
batch_size,
606+
self.config.out_channels,
604607
post_patch_num_frames * p_t,
605608
post_patch_height * p_h,
606-
post_patch_width * p_w
609+
post_patch_width * p_w,
607610
)
608611

609612
if USE_PEFT_BACKEND:

src/diffusers/pipelines/magi1/pipeline_magi1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
from transformers import AutoTokenizer, UMT5EncoderModel
2222

2323
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
24-
#from ...loaders import Magi1LoraLoaderMixin
24+
25+
# from ...loaders import Magi1LoraLoaderMixin
2526
from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel
2627
from ...schedulers import FlowMatchEulerDiscreteScheduler
2728
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
@@ -91,7 +92,7 @@ def prompt_clean(text):
9192
return text
9293

9394

94-
class Magi1Pipeline(DiffusionPipeline):#, Magi1LoraLoaderMixin):
95+
class Magi1Pipeline(DiffusionPipeline): # , Magi1LoraLoaderMixin):
9596
r"""
9697
Pipeline for text-to-video generation using Magi1.
9798

tests/pipelines/magi1/test_magi1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch
2020
from transformers import AutoTokenizer, T5EncoderModel
2121

22-
from diffusers import AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler, Magi1Transformer3DModel, Magi1Pipeline
22+
from diffusers import AutoencoderKLMagi1, FlowMatchEulerDiscreteScheduler, Magi1Pipeline, Magi1Transformer3DModel
2323
from diffusers.utils.testing_utils import (
2424
backend_empty_cache,
2525
enable_full_determinism,

tests/pipelines/magi1/test_magi1_image_to_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from diffusers import (
2323
AutoencoderKLMagi1,
2424
FlowMatchEulerDiscreteScheduler,
25-
Magi1Transformer3DModel,
2625
Magi1ImageToVideoPipeline,
26+
Magi1Transformer3DModel,
2727
)
2828
from diffusers.utils.testing_utils import (
2929
enable_full_determinism,

0 commit comments

Comments
 (0)