11import ast
2- from typing import Any , List , Optional
2+ import functools
3+ from typing import Any , Dict , List , Optional
34
45
56class OrToBitOrTransformer (ast .NodeTransformer ):
@@ -19,10 +20,129 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node":
1920 return new_node
2021
2122
22- def _rewrite_bart_encoder_layer ():
23- "BartEncoderLayer, PLBartEncoderLayer"
23+ @functools .lru_cache
24+ def _rewrite_forward_clamp_float16 () -> Dict [str , List [type ]]:
25+
2426 import transformers
2527
28+ _known = {
29+ "AutoformerEncoderLayer" : [
30+ transformers .models .autoformer .modeling_autoformer .AutoformerEncoderLayer
31+ ],
32+ "BartEncoderLayer" : [
33+ transformers .models .bart .modeling_bart .BartEncoderLayer ,
34+ transformers .models .plbart .modeling_plbart .PLBartEncoderLayer ,
35+ ],
36+ "BigBirdPegasusEncoderLayer" : [
37+ transformers .models .bigbird_pegasus .modeling_bigbird_pegasus .BigBirdPegasusEncoderLayer
38+ ],
39+ "BlenderbotSmallEncoderLayer" : [
40+ transformers .models .blenderbot_small .modeling_blenderbot_small .BlenderbotSmallEncoderLayer
41+ ],
42+ "InformerEncoderLayer" : [
43+ transformers .models .informer .modeling_informer .InformerEncoderLayer
44+ ],
45+ "LEDEncoderLayer" : [transformers .models .led .modeling_led .LEDEncoderLayer ],
46+ "MarianEncoderLayer" : [transformers .models .marian .modeling_marian .MarianEncoderLayer ],
47+ "MvpEncoderLayer" : [transformers .models .mvp .modeling_mvp .MvpEncoderLayer ],
48+ "NllbMoeEncoderLayer" : [
49+ transformers .models .nllb_moe .modeling_nllb_moe .NllbMoeEncoderLayer
50+ ],
51+ "TimeSeriesTransformerEncoderLayer" : [
52+ transformers .models .time_series_transformer .modeling_time_series_transformer .TimeSeriesTransformerEncoderLayer
53+ ],
54+ }
55+ return _known
56+
57+
58+ @functools .lru_cache
59+ def known_transformers_rewritings_clamp_float16 () -> Dict [str , str ]:
60+ """
61+ This function returns the list of known classes to be rewritten
62+ in :epkg:`transformers`. Each class is mapped to an alias,
63+ this alias is then given to :func:`rewritings_transformers_clamp_float16`
64+ to rewrite the encoder layers because of a specific control flow.
65+
66+ .. runpython::
67+ :showcode:
68+
69+ import pprint
70+ from onnx_diagnostic.torch_export_patches.patch_model_helper import (
71+ known_transformers_rewritings_clamp_float16,
72+ )
73+
74+ pprint.pprint(known_transformers_rewritings_clamp_float16())
75+ """
76+ _alias = {
77+ "AutoformerEncoder" : "AutoformerEncoderLayer" ,
78+ "AutoformerEncoderLayer" : "AutoformerEncoderLayer" ,
79+ "AutoformerForPrediction" : "AutoformerEncoderLayer" ,
80+ "AutoformerModel" : "AutoformerEncoderLayer" ,
81+ "BartEncoderLayer" : "BartEncoderLayer" ,
82+ "BartForConditionalGeneration" : "BartEncoderLayer" ,
83+ "BigBirdPegasusForConditionalGeneration" : "BigBirdPegasusEncoderLayer" ,
84+ "BigBirdPegasusForQuestionAnswering" : "BigBirdPegasusEncoderLayer" ,
85+ "BigBirdPegasusForCausalLM" : "BigBirdPegasusEncoderLayer" ,
86+ "BlenderbotSmallEncoderLayer" : "BlenderbotSmallEncoderLayer" ,
87+ "BlenderbotSmallForConditionalGeneration" : "BlenderbotSmallEncoderLayer" ,
88+ "BlenderbotSmallForCausalLM" : "BlenderbotSmallEncoderLayer" ,
89+ "InformerEncoderLayer" : "InformerEncoderLayer" ,
90+ "InformerForPrediction" : "InformerEncoderLayer" ,
91+ "LEDEncoderLayer" : "LEDEncoderLayer" ,
92+ "LEDClassificationHead" : "LEDEncoderLayer" ,
93+ "LEDForConditionalGeneration" : "LEDEncoderLayer" ,
94+ "MarianEncoderLayer" : "MarianEncoderLayer" ,
95+ "MarianEncoder" : "MarianEncoderLayer" ,
96+ "MarianModel" : "MarianEncoderLayer" ,
97+ "MarianMTModel" : "MarianEncoderLayer" ,
98+ "MvpEncoderLayer" : "MvpEncoderLayer" ,
99+ "MvpPrompt" : "MvpEncoderLayer" ,
100+ "MvpForConditionalGeneration" : "MvpEncoderLayer" ,
101+ "MvpForSequenceClassification" : "MvpEncoderLayer" ,
102+ "MvpForQuestionAnswering" : "MvpEncoderLayer" ,
103+ "MvpForCausalLM" : "MvpEncoderLayer" ,
104+ "NllbMoeEncoderLayer" : "NllbMoeEncoderLayer" ,
105+ "NllbMoeForConditionalGeneration" : "NllbMoeEncoderLayer" ,
106+ "PLBartEncoderLayer" : "BartEncoderLayer" ,
107+ "PLBartForConditionalGeneration" : "BartEncoderLayer" ,
108+ "TimeSeriesTransformerEncoderLayer" : "TimeSeriesTransformerEncoderLayer" ,
109+ "TimeSeriesTransformerForPrediction" : "TimeSeriesTransformerEncoderLayer" ,
110+ }
111+ return _alias
112+
113+
114+ def rewritings_transformers_clamp_float16 (cls_name ) -> List [type ]:
115+ """
116+ Rewrites known control flows equal to this:
117+
118+ .. code-block:: python
119+
120+ if hidden_states.dtype == torch.float16 and (
121+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
122+ ):
123+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
124+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
125+
126+ *cls_name* is the class name. It is mapped to the list of classes whose
127+ ``forward`` method is rewritten. Here is the known list:
128+
129+ .. runpython::
130+ :showcode:
131+
132+ import pprint
133+ from onnx_diagnostic.torch_export_patches.patch_model_helper import (
134+ _rewrite_forward_clamp_float16,
135+ )
136+
137+ pprint.pprint(_rewrite_forward_clamp_float16())
138+
139+ Function :func:`known_transformers_rewritings_clamp_float16` collects
140+ all model classes using those layers.
141+ """
142+ _known = _rewrite_forward_clamp_float16 ()
143+
144+ assert cls_name in _known , f"cls_name={ cls_name !r} unknown in { sorted (_known )} ."
145+
26146 bd = dict (
27147 filter_node = (
28148 lambda node : isinstance (node , ast .If ) and not isinstance (node .test , ast .Name )
@@ -35,16 +155,13 @@ def _add(f):
35155 g ["function" ] = f
36156 return g
37157
38- return [
39- _add (transformers .models .bart .modeling_bart .BartEncoderLayer .forward ),
40- _add (transformers .models .plbart .modeling_plbart .PLBartEncoderLayer .forward ),
41- ]
158+ return [_add (cls .forward ) for cls in _known [cls_name ]]
42159
43160
44161def code_needing_rewriting (cls_name : str ) -> Optional [List [Any ]]:
45162 """
46- Returns a known list of methods or functions to rewrite because of control flow
47- for a specific model class .
163+ Returns the list of known rewritings mapped to a class
164+ because of control flow. See :func:`registered_transformers_rewritings` .
48165
49166 :param cls_name: name of the class
50167 :return: a list of rewriting
@@ -59,11 +176,8 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
59176
60177 pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
61178 """
62- if cls_name in {
63- "BartEncoderLayer" ,
64- "BartForConditionalGeneration" ,
65- "PLBartEncoderLayer" ,
66- "PLBartForConditionalGeneration" ,
67- }:
68- return _rewrite_bart_encoder_layer ()
179+ aliases = known_transformers_rewritings_clamp_float16 ()
180+ if cls_name in aliases :
181+ alias = aliases [cls_name ]
182+ return rewritings_transformers_clamp_float16 (alias )
69183 return None
0 commit comments