Skip to content

Commit d0e32a1

Browse files
committed
extend rewrite list
1 parent a5c52b8 commit d0e32a1

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

onnx_diagnostic/torch_export_patches/patch_module_helper.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,39 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
6060
pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
6161
"""
6262
if cls_name in {
63+
"AutoformerEncoderLayer",
64+
"AutoformerEncoder",
65+
"AutoformerForPrediction",
6366
"BartEncoderLayer",
67+
"AutoformerModel",
6468
"BartForConditionalGeneration",
69+
"BigBirdPegasusEncoderLayer",
70+
"BigBirdPegasusForConditionalGeneration",
71+
"BigBirdPegasusForQuestionAnswering",
72+
"BigBirdPegasusForCausalLM",
73+
"BlenderbotSmallEncoderLayer",
74+
"BlenderbotSmallForConditionalGeneration",
75+
"BlenderbotSmallForCausalLM",
76+
"InformerEncoderLayer",
77+
"InformerForPrediction",
78+
"LEDEncoderLayer",
79+
"LEDClassificationHead",
80+
"LEDForConditionalGeneration",
81+
"MarianEncoderLayer",
82+
"MarianEncoder",
83+
"MarianModel",
84+
"MvpEncoderLayer",
85+
"MvpPrompt",
86+
"MvpForConditionalGeneration",
87+
"MvpForSequenceClassification",
88+
"MvpForQuestionAnswering",
89+
"MvpForCausalLM",
90+
"NllbMoeEncoderLayer",
91+
"NllbMoeForConditionalGeneration",
6592
"PLBartEncoderLayer",
6693
"PLBartForConditionalGeneration",
94+
"TimeSeriesTransformerEncoderLayer",
95+
"TimeSeriesTransformerForPrediction",
6796
}:
6897
return _rewrite_bart_encoder_layer()
6998
return None

0 commit comments

Comments
 (0)