@@ -60,10 +60,39 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
6060 pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
6161 """
6262 if cls_name in {
63+ "AutoformerEncoderLayer" ,
64+ "AutoformerEncoder" ,
65+ "AutoformerForPrediction" ,
6366 "BartEncoderLayer" ,
67+ "AutoformerModel" ,
6468 "BartForConditionalGeneration" ,
69+ "BigBirdPegasusEncoderLayer" ,
70+ "BigBirdPegasusForConditionalGeneration" ,
71+ "BigBirdPegasusForQuestionAnswering" ,
72+ "BigBirdPegasusForCausalLM" ,
73+ "BlenderbotSmallEncoderLayer" ,
74+ "BlenderbotSmallForConditionalGeneration" ,
75+ "BlenderbotSmallForCausalLM" ,
76+ "InformerEncoderLayer" ,
77+ "InformerForPrediction" ,
78+ "LEDEncoderLayer" ,
79+ "LEDClassificationHead" ,
80+ "LEDForConditionalGeneration" ,
81+ "MarianEncoderLayer" ,
82+ "MarianEncoder" ,
83+ "MarianModel" ,
84+ "MvpEncoderLayer" ,
85+ "MvpPrompt" ,
86+ "MvpForConditionalGeneration" ,
87+ "MvpForSequenceClassification" ,
88+ "MvpForQuestionAnswering" ,
89+ "MvpForCausalLM" ,
90+ "NllbMoeEncoderLayer" ,
91+ "NllbMoeForConditionalGeneration" ,
6592 "PLBartEncoderLayer" ,
6693 "PLBartForConditionalGeneration" ,
94+ "TimeSeriesTransformerEncoderLayer" ,
95+ "TimeSeriesTransformerForPrediction" ,
6796 }:
6897 return _rewrite_bart_encoder_layer ()
6998 return None
0 commit comments