import math
from typing import Any, Dict, Optional, Tuple, Union

-from typing import Optional
-
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -40,9 +38,8 @@ class Magi1AttnProcessor2_0:
    r"""
    Processor for implementing MAGI-1 attention mechanism.

-    This processor handles both self-attention and cross-attention for the MAGI-1 model,
-    following diffusers' standard attention processor interface. It supports image conditioning
-    for image-to-video generation tasks.
+    This processor handles both self-attention and cross-attention for the MAGI-1 model, following diffusers' standard
+    attention processor interface. It supports image conditioning for image-to-video generation tasks.
    """

    def __init__(self):
@@ -62,7 +59,7 @@ def __call__(
        if attn.add_k_proj is not None and encoder_hidden_states is not None:
            # Extract image conditioning from the concatenated encoder states
            # The text encoder context length is typically 512 tokens
-            text_context_length = getattr(attn, 'text_context_length', 512)
+            text_context_length = getattr(attn, "text_context_length", 512)
            image_context_length = encoder_hidden_states.shape[1] - text_context_length
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
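# A minimal standalone sketch of the split performed above, assuming the image tokens are
# concatenated in front of the 512 text tokens. The batch size, token counts, and hidden
# dimension below are illustrative, not taken from the MAGI-1 config.
import torch

encoder_hidden_states = torch.randn(2, 257 + 512, 4096)  # (batch, image tokens + text tokens, dim)
text_context_length = 512
image_context_length = encoder_hidden_states.shape[1] - text_context_length

encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]  # (2, 257, 4096)
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]      # (2, 512, 4096)
assert encoder_hidden_states.shape[1] == text_context_length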
@@ -89,6 +86,7 @@ def __call__(

        # Apply rotary embeddings if provided
        if rotary_emb is not None:
+
            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
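# The hunk above shows only the first lines of apply_rotary_emb. The sketch below completes
# the same view_as_complex style of rotary embedding the way it is commonly finished
# (multiply by the complex freqs, view_as_real, flatten); that completion and the tensor
# shapes are assumptions, not code copied from this file.
import torch

def apply_rotary_emb_sketch(hidden_states: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, heads, seq_len, head_dim); freqs: complex tensor broadcastable
    # to (seq_len, head_dim // 2).
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
    x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
    return x_out.type_as(hidden_states)

x = torch.randn(1, 8, 16, 64)
freqs = torch.polar(torch.ones(16, 32, dtype=torch.float64), torch.randn(16, 32, dtype=torch.float64))
assert apply_rotary_emb_sketch(x, freqs).shape == x.shape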
@@ -129,16 +127,17 @@ class Magi1ImageEmbedding(torch.nn.Module):
    """
    Image embedding layer for the MAGI-1 model.

-    This module processes image conditioning features for image-to-video generation tasks.
-    It applies layer normalization, a feed-forward transformation, and optional positional
-    embeddings to prepare image features for cross-attention.
+    This module processes image conditioning features for image-to-video generation tasks. It applies layer
+    normalization, a feed-forward transformation, and optional positional embeddings to prepare image features for
+    cross-attention.

    Args:
        in_features (`int`): Input feature dimension.
        out_features (`int`): Output feature dimension.
        pos_embed_seq_len (`int`, optional): Sequence length for positional embeddings.
            If provided, learnable positional embeddings will be added to the input.
    """
+
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

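# The remainder of Magi1ImageEmbedding is outside this hunk. Below is one plausible reading
# of the docstring (optional learned positional embedding, LayerNorm, then a feed-forward
# projection); the layer names, the FF shape, and the order of operations are assumptions,
# not the actual implementation.
import torch
import torch.nn as nn

class ImageEmbeddingSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()
        self.norm = nn.LayerNorm(in_features)
        self.ff = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.GELU(),
            nn.Linear(out_features, out_features),
        )
        self.pos_embed = (
            nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) if pos_embed_seq_len else None
        )

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        # image_embeds: (batch, seq_len, in_features) image-encoder features.
        if self.pos_embed is not None:
            image_embeds = image_embeds + self.pos_embed
        return self.ff(self.norm(image_embeds))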
@@ -179,6 +178,7 @@ class Magi1TimeTextImageEmbedding(nn.Module):
        image_embed_dim (`int`, optional): Input dimension of image embeddings.
        pos_embed_seq_len (`int`, optional): Sequence length for image positional embeddings.
    """
+
    def __init__(
        self,
        dim: int,
@@ -269,9 +269,9 @@ class Magi1TransformerBlock(nn.Module):
    """
    A transformer block used in the MAGI-1 model.

-    This block follows diffusers' design philosophy with separate self-attention (attn1)
-    and cross-attention (attn2) modules, while faithfully implementing the original
-    MAGI-1 logic through appropriate parameter mapping during conversion.
+    This block follows diffusers' design philosophy with separate self-attention (attn1) and cross-attention (attn2)
+    modules, while faithfully implementing the original MAGI-1 logic through appropriate parameter mapping during
+    conversion.

    Args:
        dim (`int`): The number of channels in the input and output.
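# A schematic of the attn1/attn2 split described in the docstring, assuming the usual
# diffusers block layout (self-attention, then cross-attention, then feed-forward, each
# with a residual connection). Module names (norm1/attn1/norm2/attn2/norm3/ffn) and the
# omitted adaptive layer-norm modulation are assumptions, not the real block's forward.
def block_forward_sketch(block, hidden_states, encoder_hidden_states, rotary_emb=None):
    # Self-attention over the video token sequence (attn1).
    hidden_states = hidden_states + block.attn1(block.norm1(hidden_states), rotary_emb=rotary_emb)
    # Cross-attention against text (and optional image) conditioning tokens (attn2).
    hidden_states = hidden_states + block.attn2(block.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
    # Position-wise feed-forward network.
    hidden_states = hidden_states + block.ffn(block.norm3(hidden_states))
    return hidden_states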
@@ -369,9 +369,9 @@ class Magi1Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
    r"""
    A Transformer model for video-like data used in the Magi1 model.

-    This model implements a 3D transformer architecture for video generation with support for text conditioning
-    and optional image conditioning. The model uses rotary position embeddings and adaptive layer normalization
-    for temporal and spatial modeling.
+    This model implements a 3D transformer architecture for video generation with support for text conditioning and
+    optional image conditioning. The model uses rotary position embeddings and adaptive layer normalization for
+    temporal and spatial modeling.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
@@ -542,11 +542,7 @@ def forward(
        # Patchify: (B, C, T, H, W) -> (B, T//p_t, H//p_h, W//p_w, C*p_t*p_h*p_w)
        hidden_states = hidden_states.unfold(2, p_t, p_t).unfold(3, p_h, p_h).unfold(4, p_w, p_w)
        hidden_states = hidden_states.contiguous().view(
-            batch_size,
-            num_frames // p_t,
-            height // p_h,
-            width // p_w,
-            num_channels * p_t * p_h * p_w
+            batch_size, num_frames // p_t, height // p_h, width // p_w, num_channels * p_t * p_h * p_w
        )
        # Reshape to sequence: (B, T*H*W, C*p_t*p_h*p_w)
        hidden_states = hidden_states.flatten(1, 3)
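# A quick shape check of the unfold-based patchify above, with small illustrative sizes;
# it only verifies the tensor shapes stated in the comments, not the downstream layout.
import torch

batch_size, num_channels, num_frames, height, width = 1, 16, 4, 8, 8
p_t, p_h, p_w = 1, 2, 2
x = torch.randn(batch_size, num_channels, num_frames, height, width)

x = x.unfold(2, p_t, p_t).unfold(3, p_h, p_h).unfold(4, p_w, p_w)
assert x.shape == (1, 16, 4, 4, 4, 1, 2, 2)  # (B, C, T//p_t, H//p_h, W//p_w, p_t, p_h, p_w)

x = x.contiguous().view(batch_size, num_frames // p_t, height // p_h, width // p_w, num_channels * p_t * p_h * p_w)
x = x.flatten(1, 3)
assert x.shape == (1, 64, 64)  # one token per spatiotemporal patch: (B, T'*H'*W', C*p_t*p_h*p_w)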
@@ -595,15 +591,22 @@ def forward(
        # Rearrange patches: (B, T//p_t, H//p_h, W//p_w, C*p_t*p_h*p_w) -> (B, C, T, H, W)
        p_t, p_h, p_w = self.config.patch_size
        hidden_states = hidden_states.view(
-            batch_size, post_patch_num_frames, post_patch_height, post_patch_width,
-            self.config.out_channels, p_t, p_h, p_w
+            batch_size,
+            post_patch_num_frames,
+            post_patch_height,
+            post_patch_width,
+            self.config.out_channels,
+            p_t,
+            p_h,
+            p_w,
        )
        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
        output = hidden_states.contiguous().view(
-            batch_size, self.config.out_channels,
+            batch_size,
+            self.config.out_channels,
            post_patch_num_frames * p_t,
            post_patch_height * p_h,
-            post_patch_width * p_w
+            post_patch_width * p_w,
        )

        if USE_PEFT_BACKEND:
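# A matching shape check for the unpatchify path above; out_channels and the patch grid
# are illustrative, only the view -> permute -> view shape bookkeeping is demonstrated.
import torch

batch_size, out_channels = 1, 16
post_patch_num_frames, post_patch_height, post_patch_width = 4, 4, 4
p_t, p_h, p_w = 1, 2, 2

tokens = torch.randn(
    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, out_channels * p_t * p_h * p_w
)
x = tokens.view(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, out_channels, p_t, p_h, p_w)
x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)  # (B, C, T', p_t, H', p_h, W', p_w)
output = x.contiguous().view(
    batch_size, out_channels, post_patch_num_frames * p_t, post_patch_height * p_h, post_patch_width * p_w
)
assert output.shape == (1, 16, 4, 8, 8)  # (B, C, T, H, W)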