Skip to content

Commit e4a0e5d

Browse files
vasqurgtjf
authored and committed
[FlexAttention] Reenable flex for encoder-decoder and make the test more robust (huggingface#38321)
* reenable most flex attention test cases
* style
* trigger
* trigger
1 parent 69dfee1 commit e4a0e5d

26 files changed

+37
-55
lines changed

src/transformers/models/bart/modeling_bart.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -494,8 +494,7 @@ class BartPreTrainedModel(PreTrainedModel):
494494
_skip_keys_device_placement = "past_key_values"
495495
_supports_flash_attn_2 = True
496496
_supports_sdpa = True
497-
# Compile issues
498-
_supports_flex_attn = False
497+
_supports_flex_attn = True
499498
_supports_cache_class = True
500499
_supports_static_cache = True
501500

src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -348,8 +348,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
348348
supports_gradient_checkpointing = True
349349
_supports_flash_attn_2 = True
350350
_supports_sdpa = True
351-
# Compile issues
352-
_supports_flex_attn = False
351+
_supports_flex_attn = True
353352
_supports_cache_class = True
354353
_supports_static_cache = True
355354

src/transformers/models/biogpt/modular_biogpt.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -175,8 +175,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
175175
supports_gradient_checkpointing = True
176176
_supports_flash_attn_2 = True
177177
_supports_sdpa = True
178-
# Compile issues
179-
_supports_flex_attn = False
178+
_supports_flex_attn = True
180179
_supports_cache_class = True
181180
_supports_static_cache = True
182181

src/transformers/models/blenderbot/modeling_blenderbot.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -464,8 +464,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
464464
supports_gradient_checkpointing = True
465465
_supports_flash_attn_2 = True
466466
_supports_sdpa = True
467-
# Compile issues
468-
_supports_flex_attn = False
467+
_supports_flex_attn = True
469468
_supports_cache_class = True
470469
_supports_static_cache = True
471470

src/transformers/models/blenderbot_small/modeling_blenderbot_small.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -452,8 +452,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
452452
supports_gradient_checkpointing = True
453453
_supports_flash_attn_2 = True
454454
_supports_sdpa = True
455-
# Compile issues
456-
_supports_flex_attn = False
455+
_supports_flex_attn = True
457456
_supports_cache_class = True
458457
_supports_static_cache = True
459458

src/transformers/models/data2vec/modeling_data2vec_audio.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -551,8 +551,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
551551
supports_gradient_checkpointing = True
552552
_supports_flash_attn_2 = True
553553
_supports_sdpa = True
554-
# Compile issues
555-
_supports_flex_attn = False
554+
_supports_flex_attn = True
556555

557556
def _init_weights(self, module):
558557
"""Initialize the weights"""

src/transformers/models/data2vec/modular_data2vec_audio.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -140,8 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
140140
supports_gradient_checkpointing = True
141141
_supports_flash_attn_2 = True
142142
_supports_sdpa = True
143-
# Compile issues
144-
_supports_flex_attn = False
143+
_supports_flex_attn = True
145144

146145
def _init_weights(self, module):
147146
"""Initialize the weights"""

src/transformers/models/hubert/modeling_hubert.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -738,8 +738,7 @@ class HubertPreTrainedModel(PreTrainedModel):
738738
supports_gradient_checkpointing = True
739739
_supports_flash_attn_2 = True
740740
_supports_sdpa = True
741-
# Compile issues
742-
_supports_flex_attn = False
741+
_supports_flex_attn = True
743742

744743
def _init_weights(self, module):
745744
"""Initialize the weights"""

src/transformers/models/hubert/modular_hubert.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -131,8 +131,7 @@ class HubertPreTrainedModel(PreTrainedModel):
131131
supports_gradient_checkpointing = True
132132
_supports_flash_attn_2 = True
133133
_supports_sdpa = True
134-
# Compile issues
135-
_supports_flex_attn = False
134+
_supports_flex_attn = True
136135

137136
def _init_weights(self, module):
138137
"""Initialize the weights"""

src/transformers/models/m2m_100/modeling_m2m_100.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -530,8 +530,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
530530
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
531531
_supports_flash_attn_2 = True
532532
_supports_sdpa = True
533-
# Compile issues
534-
_supports_flex_attn = False
533+
_supports_flex_attn = True
535534
_supports_cache_class = True
536535
# Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model
537536
_supports_static_cache = False

0 commit comments

Comments
 (0)