
Commit d316ba6

Extends rewrite list for cond (#114)

* extend rewrite list
* improve rewritings
* spell
* fix
* shrinks

1 parent a5c52b8 commit d316ba6


5 files changed (+153, -23 lines)


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.6.1
 +++++
 
+* :pr:`114`: extends the list of known rewritings
 * :pr:`113`: fixes a couple of issues with ModelBuilder
 
 0.6.0

onnx_diagnostic/helpers/helper.py

Lines changed: 4 additions & 4 deletions
@@ -1306,11 +1306,11 @@ def max_diff(
         rdiff = diff / (exp_cpu.abs() + 1e-3)
         if diff.numel() > 0:
             abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
-                float(diff.max()),
-                float(rdiff.max()),
-                float(diff.sum()),
+                float(diff.max().detach()),
+                float(rdiff.max().detach()),
+                float(diff.sum().detach()),
                 float(diff.numel()),
-                float(ndiff.sum()),
+                float(ndiff.sum().detach()),
             )
             argm = tuple(map(int, torch.unravel_index(diff.argmax(), diff.shape)))
         elif got_cpu.numel() == exp_cpu.numel():
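
The only change here is the added .detach() before converting each reduction to a Python float. A minimal sketch of the pattern (a toy example, not the library's max_diff): detaching first means the scalar extraction behaves the same whether or not the compared tensors carry an autograd graph.

import torch

exp = torch.randn(4, 4)
got = torch.randn(4, 4, requires_grad=True)

# Same pattern as the diff above: detach before float() so no autograd graph
# is kept alive just to read a scalar metric.
diff = (got - exp).abs()
abs_diff = float(diff.max().detach())
rel_diff = float((diff / (exp.abs() + 1e-3)).max().detach())
print(abs_diff, rel_diff)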

onnx_diagnostic/tasks/image_classification.py

Lines changed: 2 additions & 2 deletions
@@ -58,8 +58,8 @@ def get_inputs(
     shapes = {
         "pixel_values": {
             0: torch.export.Dim("batch", min=1, max=1024),
-            2: torch.export.Dim("width", min=1, max=4096),
-            3: torch.export.Dim("height", min=1, max=4096),
+            2: "width",
+            3: "height",
         },
     }
     inputs = dict(
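
The width and height dimensions now use torch.export's string shorthand instead of explicit torch.export.Dim objects. A hedged sketch on a toy module (not the library's get_inputs), assuming a PyTorch version recent enough to accept string names in dynamic_shapes; on older versions the Dim form is required.

import torch

class Toy(torch.nn.Module):
    def forward(self, pixel_values):
        # reduce over the dynamic spatial dimensions
        return pixel_values.mean(dim=(2, 3))

pixel_values = torch.randn(2, 3, 224, 224)
dynamic_shapes = {
    "pixel_values": {
        0: torch.export.Dim("batch", min=1, max=1024),  # explicit Dim with bounds
        2: "width",   # string shorthand, as in the new code
        3: "height",
    }
}
ep = torch.export.export(Toy(), (pixel_values,), dynamic_shapes=dynamic_shapes)
print(ep)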

onnx_diagnostic/torch_export_patches/patch_module_helper.py

Lines changed: 130 additions & 16 deletions
@@ -1,5 +1,6 @@
 import ast
-from typing import Any, List, Optional
+import functools
+from typing import Any, Dict, List, Optional
 
 
 class OrToBitOrTransformer(ast.NodeTransformer):
@@ -19,10 +20,129 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node":
     return new_node
 
 
-def _rewrite_bart_encoder_layer():
-    "BartEncoderLayer, PLBartEncoderLayer"
+@functools.lru_cache
+def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]:
+
     import transformers
 
+    _known = {
+        "AutoformerEncoderLayer": [
+            transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer
+        ],
+        "BartEncoderLayer": [
+            transformers.models.bart.modeling_bart.BartEncoderLayer,
+            transformers.models.plbart.modeling_plbart.PLBartEncoderLayer,
+        ],
+        "BigBirdPegasusEncoderLayer": [
+            transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer
+        ],
+        "BlenderbotSmallEncoderLayer": [
+            transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer
+        ],
+        "InformerEncoderLayer": [
+            transformers.models.informer.modeling_informer.InformerEncoderLayer
+        ],
+        "LEDEncoderLayer": [transformers.models.led.modeling_led.LEDEncoderLayer],
+        "MarianEncoderLayer": [transformers.models.marian.modeling_marian.MarianEncoderLayer],
+        "MvpEncoderLayer": [transformers.models.mvp.modeling_mvp.MvpEncoderLayer],
+        "NllbMoeEncoderLayer": [
+            transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer
+        ],
+        "TimeSeriesTransformerEncoderLayer": [
+            transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer
+        ],
+    }
+    return _known
+
+
+@functools.lru_cache
+def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
+    """
+    Returns the list of known classes in :epkg:`transformers` to be rewritten.
+    Each class is mapped to an alias; this alias is then given to
+    :func:`rewritings_transformers_clamp_float16` to rewrite the encoder
+    layers because of a specific control flow.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        from onnx_diagnostic.torch_export_patches.patch_module_helper import (
+            known_transformers_rewritings_clamp_float16,
+        )
+
+        pprint.pprint(known_transformers_rewritings_clamp_float16())
+    """
+    _alias = {
+        "AutoformerEncoder": "AutoformerEncoderLayer",
+        "AutoformerEncoderLayer": "AutoformerEncoderLayer",
+        "AutoformerForPrediction": "AutoformerEncoderLayer",
+        "AutoformerModel": "AutoformerEncoderLayer",
+        "BartEncoderLayer": "BartEncoderLayer",
+        "BartForConditionalGeneration": "BartEncoderLayer",
+        "BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
+        "BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
+        "BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",
+        "BlenderbotSmallEncoderLayer": "BlenderbotSmallEncoderLayer",
+        "BlenderbotSmallForConditionalGeneration": "BlenderbotSmallEncoderLayer",
+        "BlenderbotSmallForCausalLM": "BlenderbotSmallEncoderLayer",
+        "InformerEncoderLayer": "InformerEncoderLayer",
+        "InformerForPrediction": "InformerEncoderLayer",
+        "LEDEncoderLayer": "LEDEncoderLayer",
+        "LEDClassificationHead": "LEDEncoderLayer",
+        "LEDForConditionalGeneration": "LEDEncoderLayer",
+        "MarianEncoderLayer": "MarianEncoderLayer",
+        "MarianEncoder": "MarianEncoderLayer",
+        "MarianModel": "MarianEncoderLayer",
+        "MarianMTModel": "MarianEncoderLayer",
+        "MvpEncoderLayer": "MvpEncoderLayer",
+        "MvpPrompt": "MvpEncoderLayer",
+        "MvpForConditionalGeneration": "MvpEncoderLayer",
+        "MvpForSequenceClassification": "MvpEncoderLayer",
+        "MvpForQuestionAnswering": "MvpEncoderLayer",
+        "MvpForCausalLM": "MvpEncoderLayer",
+        "NllbMoeEncoderLayer": "NllbMoeEncoderLayer",
+        "NllbMoeForConditionalGeneration": "NllbMoeEncoderLayer",
+        "PLBartEncoderLayer": "BartEncoderLayer",
+        "PLBartForConditionalGeneration": "BartEncoderLayer",
+        "TimeSeriesTransformerEncoderLayer": "TimeSeriesTransformerEncoderLayer",
+        "TimeSeriesTransformerForPrediction": "TimeSeriesTransformerEncoderLayer",
+    }
+    return _alias
+
+
+def rewritings_transformers_clamp_float16(cls_name) -> List[type]:
+    """
+    Rewrites known control flows equal to this:
+
+    .. code-block:: python
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+    *cls_name* is the class name. It is mapped to a list of classes
+    whose forward method is rewritten. Here is the known list:
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        from onnx_diagnostic.torch_export_patches.patch_module_helper import (
+            _rewrite_forward_clamp_float16,
+        )
+
+        pprint.pprint(_rewrite_forward_clamp_float16())
+
+    Function :func:`known_transformers_rewritings_clamp_float16` collects
+    all model classes using those layers.
+    """
+    _known = _rewrite_forward_clamp_float16()
+
+    assert cls_name in _known, f"cls_name={cls_name!r} unknown in {sorted(_known)}."
+
     bd = dict(
         filter_node=(
             lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
@@ -35,16 +155,13 @@ def _add(f):
         g["function"] = f
         return g
 
-    return [
-        _add(transformers.models.bart.modeling_bart.BartEncoderLayer.forward),
-        _add(transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward),
-    ]
+    return [_add(cls.forward) for cls in _known[cls_name]]
 
 
 def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
     """
-    Returns a known list of methods or functions to rewrite because of control flow
-    for a specific model class.
+    Returns the known list of rewritings needed by a given model class
+    because of control flow. See :func:`known_transformers_rewritings_clamp_float16`.
 
     :param cls_name: name of the class
     :return: a list of rewriting
@@ -59,11 +176,8 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
 
         pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
     """
-    if cls_name in {
-        "BartEncoderLayer",
-        "BartForConditionalGeneration",
-        "PLBartEncoderLayer",
-        "PLBartForConditionalGeneration",
-    }:
-        return _rewrite_bart_encoder_layer()
+    aliases = known_transformers_rewritings_clamp_float16()
+    if cls_name in aliases:
+        alias = aliases[cls_name]
+        return rewritings_transformers_clamp_float16(alias)
     return None
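
The control flow quoted in the docstring above is the reason these classes need rewriting: the branch depends on tensor values (isinf/isnan), which torch.export cannot trace as a plain Python if. A hedged sketch of the exportable shape such a guard can take with torch.cond; it only illustrates the idea, it is not the code the rewriting generates.

import torch

class ClampIfOverflow(torch.nn.Module):
    def forward(self, hidden_states):
        def clamp(x):
            # float16 overflow guard from the docstring above
            clamp_value = torch.finfo(x.dtype).max - 1000
            return torch.clamp(x, min=-clamp_value, max=clamp_value)

        def identity(x):
            return x.clone()  # cond branches should not return inputs unchanged

        bad = torch.isinf(hidden_states).any() | torch.isnan(hidden_states).any()
        return torch.cond(bad, clamp, identity, (hidden_states,))

x = torch.randn(2, 8, dtype=torch.float16)
print(torch.export.export(ClampIfOverflow(), (x,)))

With the extended tables, code_needing_rewriting("BartForConditionalGeneration"), or any of the aliases listed above, resolves to the matching rewriting spec instead of only the Bart/PLBart pair.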

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 16 additions & 1 deletion
@@ -209,6 +209,19 @@ def _quiet_or_not_quiet(
     return res
 
 
+def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    """Shrinks the configuration before it gets added to the information to log."""
+    new_cfg = {}
+    for k, v in cfg.items():
+
+        new_cfg[k] = (
+            v
+            if (not isinstance(v, (list, tuple, set, dict)) or len(v) < 50)
+            else (v.__class__("...") if isinstance(v, (list, tuple)) else "...")
+        )
+    return new_cfg
+
+
 def validate_model(
     model_id: str,
     task: Optional[str] = None,
@@ -436,7 +449,9 @@ def validate_model(
     if summary["model_module"] in sys.modules:
         summary["model_file"] = str(sys.modules[summary["model_module"]].__file__)  # type: ignore[index]
     summary["model_config_class"] = data["configuration"].__class__.__name__
-    summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "")
+    summary["model_config"] = str(shrink_config(data["configuration"].to_dict())).replace(
+        " ", ""
+    )
     summary["model_id"] = model_id
 
     if verbose:
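
shrink_config keeps the logged summary readable: any container with 50 or more elements in the configuration dictionary is collapsed to a placeholder before the whole dict is stringified with spaces stripped. A quick usage sketch; the import path matches the diff above, so it assumes a build of onnx_diagnostic that includes this commit.

from onnx_diagnostic.torch_models.test_helper import shrink_config

cfg = {"hidden_size": 1024, "id2label": {i: f"LABEL_{i}" for i in range(200)}}
print(shrink_config(cfg))
# {'hidden_size': 1024, 'id2label': '...'}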
