@@ -250,15 +250,21 @@ def forward(
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
         joint_attention_kwargs = joint_attention_kwargs or {}
+
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

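The broadcast above turns a per-token 0/1 padding mask of shape [batch, seq_len] into a pairwise mask of shape [batch, 1, seq_len, seq_len], where entry (i, j) is nonzero only when both token i and token j are real (non-padded) tokens. A minimal sketch of the shapes, plus one common way such a mask is consumed by scaled dot-product attention (an illustration only; the exact handling inside the attention processor is not shown in this diff):

import torch
import torch.nn.functional as F

token_mask = torch.tensor([[1, 1, 1, 0]])  # [batch=1, seq_len=4], last token is padding
pairwise = token_mask[:, None, None, :] * token_mask[:, None, :, None]
print(pairwise.shape)   # torch.Size([1, 1, 4, 4])
print(pairwise[0, 0])   # last row and last column are zeroed out

# Hypothetical consumption: interpret 1 as "may attend" and pass a boolean mask
# to scaled_dot_product_attention. Fully masked query rows (the padded positions
# themselves) produce NaNs and would need to be ignored downstream.
q = k = v = torch.randn(1, 2, 4, 8)  # [batch, heads, seq_len, head_dim]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=pairwise.bool())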
@@ -312,6 +318,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +328,15 @@ def forward(
             encoder_hidden_states, emb=temb_txt
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

@@ -570,6 +581,7 @@ def forward(
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
@@ -659,11 +671,7 @@ def forward(
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
                 )

             else:
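One note on the checkpointed path: self._gradient_checkpointing_func forwards its arguments positionally (as torch.utils.checkpoint.checkpoint does), which is why attention_mask is appended in the block's signature order rather than passed by keyword; joint_attention_kwargs is still not forwarded on this path, matching the pre-existing code. A minimal sketch of the positional forwarding, assuming a plain checkpoint wrapper:

from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for self._gradient_checkpointing_func: arguments are
# forwarded positionally, so the call order must mirror block.forward(...)
# (hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask).
def run_block_checkpointed(block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask):
    return checkpoint(
        block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask,
        use_reentrant=False,
    )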
@@ -672,6 +680,7 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

@@ -704,6 +713,7 @@ def forward(
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

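For context, a hedged sketch of how the new top-level attention_mask argument might be supplied by a caller. The model variable, the text-then-image sequence layout, and the lengths below are assumptions for illustration; the diff itself only adds the parameter and threads it through the blocks:

import torch

batch, text_len, image_len = 2, 512, 1024
text_mask = torch.ones(batch, text_len)
text_mask[:, 300:] = 0                      # pretend both prompts are padded past token 300
image_mask = torch.ones(batch, image_len)   # image tokens are assumed never padded
attention_mask = torch.cat([text_mask, image_mask], dim=1)  # [batch, text_len + image_len]

# Hypothetical invocation (argument names taken from or implied by the diff):
# out = transformer(
#     hidden_states=latents,
#     encoder_hidden_states=prompt_embeds,
#     timestep=timestep,
#     img_ids=img_ids,
#     txt_ids=txt_ids,
#     attention_mask=attention_mask,
#     joint_attention_kwargs=None,
# )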