Skip to content

Commit 3ddec0a

Browse files
committed
fix a few things
1 parent a13e624 commit 3ddec0a

File tree

7 files changed

+121
-10
lines changed

7 files changed

+121
-10
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import unittest
2+
from onnx_diagnostic.ext_test_case import ExtTestCase
3+
from onnx_diagnostic.torch_models.hghub.hub_data import code_needing_rewriting
4+
5+
6+
class TestHuggingFaceHubModelRewrite(ExtTestCase):
7+
8+
def test_code_needing_rewriting(self):
9+
self.assertEqual(1, len(code_needing_rewriting("BartForConditionalGeneration")))
10+
11+
12+
if __name__ == "__main__":
13+
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,17 @@ def get_parser_validate() -> ArgumentParser:
309309
help="catches exception, report them in the summary",
310310
)
311311
parser.add_argument(
312-
"-p",
313312
"--patch",
314313
default=True,
315314
action=BooleanOptionalAction,
316315
help="applies patches before exporting",
317316
)
317+
parser.add_argument(
318+
"--rewrite",
319+
default=True,
320+
action=BooleanOptionalAction,
321+
help="applies rewrite before exporting",
322+
)
318323
parser.add_argument(
319324
"--stop-if-static",
320325
default=0,
@@ -411,6 +416,7 @@ def _cmd_validate(argv: List[Any]):
411416
dtype=args.dtype,
412417
device=args.device,
413418
patch=args.patch,
419+
rewrite=args.rewrite,
414420
stop_if_static=args.stop_if_static,
415421
optimization=args.opt,
416422
exporter=args.export,

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def torch_export_patches(
103103
patch: bool = True,
104104
custom_patches: Optional[List[type["torch.nn.Module"]]] = None, # noqa: F821
105105
rewrite: Optional[List[Callable]] = None,
106+
dump_rewriting: Optional[str] = None,
106107
) -> Callable:
107108
"""
108109
Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -130,6 +131,7 @@ def torch_export_patches(
130131
this is done by function :func:`transform_method
131132
<onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
132133
its documentation provides possible values
134+
:param dump_rewriting: dumps rewriting information in file beginning with that prefix
133135
:param verbose: to show which patches is applied
134136
135137
The list of available patches.
@@ -186,7 +188,9 @@ def torch_export_patches(
186188
if rewrite:
187189
from .patch_module import torch_export_rewrite
188190

189-
with torch_export_rewrite(rewrite=rewrite, verbose=verbose), torch_export_patches( # type: ignore[var-annotated]
191+
with torch_export_rewrite( # type: ignore[var-annotated]
192+
rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
193+
), torch_export_patches(
190194
patch_sympy=patch_sympy,
191195
patch_torch=patch_torch,
192196
patch_transformers=patch_transformers,

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,18 @@ def _rewrite_if(
177177
else_ret = else_exprs[0]
178178
then_exprs = [n for n in node.body if not isinstance(n, ast.Return)]
179179
else_exprs = [n for n in node.orelse if not isinstance(n, ast.Return)]
180+
assert type(then_ret.value) is type(else_ret.value), (
181+
f"Inconsistencies return then value={then_ret.value}, "
182+
f"else value={else_ret.value}"
183+
)
184+
if isinstance(then_ret.value, (ast.Tuple, ast.list)):
185+
assert len(then_ret.value.elts) == len(else_ret.value.elts), (
186+
f"Unexpected number of elements on both branches, "
187+
f"then:{then_ret.value.elts}, else:{else_ret.value.elts}"
188+
)
189+
n_returned_values = len(then_ret.value.elts)
190+
else:
191+
n_returned_values = 0
180192
else:
181193
self._check(
182194
tgt_mapping,
@@ -207,6 +219,7 @@ def _rewrite_if(
207219
if len(else_rets) == 1
208220
else ast.Tuple([self._clone(r) for r in else_rets], ctx=ast.Load())
209221
)
222+
n_returned_values = len(then_rets) if len(then_rets) > 1 else 0
210223

211224
# build local funcs
212225
then_def = ast.FunctionDef(
@@ -258,7 +271,7 @@ def _rewrite_if(
258271
],
259272
keywords=[],
260273
)
261-
return then_def, else_def, call, drop
274+
return then_def, else_def, call, drop, n_returned_values
262275

263276
def _filter_target(self, node, tgt_mapping):
264277
"""
@@ -330,17 +343,32 @@ def visit_If(self, node):
330343
# the targets we need to export
331344
tgt, tgt_mapping = self._make_targets(node, then_assigns, else_assigns)
332345

333-
then_def, else_def, call, dropped = self._rewrite_if(
346+
then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
334347
node,
335348
then_assigns,
336349
else_assigns,
337350
tgt_mapping=tgt_mapping,
338351
known_local_variables=known_local_variables,
339352
)
340353
if dropped and isinstance(tgt, ast.Tuple):
341-
tgt = ast.Tuple(
342-
tuple(t for t in tgt.elts if t.id not in dropped), ctx=ast.Store()
354+
tgt_elts = tuple(t for t in tgt.elts if t.id not in dropped)
355+
elif isinstance(tgt, ast.Tuple):
356+
tgt_elts = tuple(t for t in tgt.elts if t.id not in dropped)
357+
else:
358+
tgt_elts = [tgt]
359+
360+
if n_returned_values == 0:
361+
assert len(tgt_elts) == 1, (
362+
f"Inconsistencies between n_returned_values={n_returned_values}, "
363+
f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
343364
)
365+
tgt = tgt_elts[0]
366+
else:
367+
assert n_returned_values == len(tgt_elts), (
368+
f"Inconsistencies between n_returned_values={n_returned_values}, "
369+
f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
370+
)
371+
tgt = ast.Tuple(tgt_elts, ctx=ast.Store())
344372

345373
added = {tgt.id} if isinstance(tgt, ast.Name) else set(t.id for t in tgt.elts)
346374
assign = ast.Assign(targets=[tgt], value=call)
@@ -364,7 +392,7 @@ def visit_If(self, node):
364392
)
365393
then_expr = then_ret.value
366394
else_expr = else_ret.value
367-
then_def, else_def, call, dropped = self._rewrite_if(
395+
then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
368396
node, [then_expr], [else_expr], known_local_variables=known_local_variables
369397
)
370398
ret = ast.Return(call)
@@ -809,7 +837,9 @@ def forward(self, x, y):
809837

810838
@contextlib.contextmanager
811839
def torch_export_rewrite(
812-
rewrite: Optional[List[Union[Tuple[type, str], Callable]]] = None, verbose: int = 0
840+
rewrite: Optional[List[Union[Tuple[type, str], Callable]]] = None,
841+
dump_rewriting: Optional[str] = None,
842+
verbose: int = 0,
813843
):
814844
"""
815845
Automatically rewrite the methods given in `rewrite` to export
@@ -818,6 +848,7 @@ def torch_export_rewrite(
818848
:param rewrite: methods of functions to rewrite, if not empty, the function may try
819849
to discover them, a method is defined by its class (a type) and its name
820850
if the class is local, by itself otherwise
851+
:param dump_rewriting: dumps rewriting information in file beginning with that prefix
821852
:param verbose: verbosity, up to 10, 10 shows the rewritten code,
822853
``verbose=1`` shows the rewritten function,
823854
``verbose=2`` shows the rewritten code as well
@@ -890,6 +921,7 @@ def forward(self, x, y):
890921
f"__globals__={sorted(me.__globals__)}"
891922
)
892923
mod = sys.modules[module]
924+
cls_name = module
893925
cls = mod
894926
name = name
895927
to_rewrite = me
@@ -916,7 +948,19 @@ def forward(self, x, y):
916948
if verbose:
917949
print(f"[torch_export_rewrite] rewrites {kind} {cls.__name__}.{name}")
918950
keep[cls, name] = to_rewrite
951+
if dump_rewriting:
952+
filename = f"{dump_rewriting}.{kind}.{cls_name}.{name}.original.py"
953+
if verbose:
954+
print(f"[torch_export_rewrite] dump original code in {filename!r}")
955+
with open(filename, "w") as f:
956+
f.write(inspect.getsource(to_rewrite))
919957
rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0))
958+
if dump_rewriting:
959+
filename = f"{dump_rewriting}.{kind}.{cls_name}.{name}.rewritten.py"
960+
if verbose:
961+
print(f"[torch_export_rewrite] dump rewritten code in {filename!r}")
962+
with open(filename, "w") as f:
963+
f.write(rewr.code)
920964
setattr(cls, name, rewr.func)
921965

922966
try:

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import functools
33
import textwrap
4-
from typing import Dict, List
4+
from typing import Any, Dict, List, Optional
55

66
__date__ = "2025-03-26"
77

@@ -199,6 +199,28 @@
199199
"""
200200

201201

202+
def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
203+
"""
204+
Returns a known list of methods or functions to rewrite because of control flow
205+
for a specific model class.
206+
207+
:param cls_name: name of the class
208+
:return: a list of rewriting
209+
210+
.. runpython::
211+
:showcode:
212+
213+
from onnx_diagnostic.torch_models.hghub.hub_data import code_needing_rewriting
214+
215+
print(code_needing_rewriting("BartForConditionalGeneration"))
216+
"""
217+
if cls_name in {"BartForConditionalGeneration", "BartEncoderLayer"}:
218+
import transformers
219+
220+
return [transformers.models.bart.modeling_bart.BartEncoderLayer.forward]
221+
return None
222+
223+
202224
@functools.cache
203225
def load_models_testing() -> List[str]:
204226
"""Returns model ids for testing."""

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ...helpers.config_helper import update_config
77
from ...tasks import reduce_model_config, random_input_kwargs
88
from .hub_api import task_from_arch, task_from_id, get_pretrained_config
9+
from .hub_data import code_needing_rewriting
910

1011

1112
def get_untrained_model_with_inputs(
@@ -40,7 +41,8 @@ def get_untrained_model_with_inputs(
4041
:param add_second_input: provides a second inputs to check a model
4142
supports different shapes
4243
:param subfolder: subfolder to use for this model id
43-
:return: dictionary with a model, inputs, dynamic shapes, and the configuration
44+
:return: dictionary with a model, inputs, dynamic shapes, and the configuration,
45+
some necessary rewriting as well
4446
4547
Example:
4648
@@ -170,6 +172,10 @@ def get_untrained_model_with_inputs(
170172
if k.startswith(("inputs", "dynamic_shapes")) and isinstance(v, dict):
171173
update[k] = filter_out_unexpected_inputs(model, v, verbose=verbose)
172174
res.update(update)
175+
176+
rewrite = code_needing_rewriting(model.__class__.__name__)
177+
if rewrite:
178+
res["rewrite"] = rewrite
173179
return res
174180

175181

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def validate_model(
222222
optimization: Optional[str] = None,
223223
quiet: bool = False,
224224
patch: bool = False,
225+
rewrite: bool = False,
225226
stop_if_static: int = 1,
226227
dump_folder: Optional[str] = None,
227228
drop_inputs: Optional[List[str]] = None,
@@ -250,6 +251,8 @@ def validate_model(
250251
:param quiet: if quiet, catches exception if any issue
251252
:param patch: applies patches (``patch_transformers=True``) before exporting,
252253
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
254+
:param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
255+
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
253256
:param stop_if_static: stops if a dynamic dimension becomes static,
254257
see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
255258
:param dump_folder: dumps everything in a subfolder of this one
@@ -270,6 +273,9 @@ def validate_model(
270273
271274
* ``PRINT_CONFIG``: prints the model configuration
272275
"""
276+
assert (
277+
not rewrite or patch
278+
), f"rewrite={rewrite}, patch={patch}, patch must be True to enable rewriting"
273279
summary = version_summary()
274280
summary.update(
275281
dict(
@@ -281,6 +287,7 @@ def validate_model(
281287
version_optimization=optimization or "",
282288
version_quiet=str(quiet),
283289
version_patch=str(patch),
290+
version_rewrite=str(rewrite),
284291
version_dump_folder=dump_folder or "",
285292
version_drop_inputs=str(list(drop_inputs or "")),
286293
version_ortfusiontype=ortfusiontype or "",
@@ -336,6 +343,13 @@ def validate_model(
336343
)
337344
data["input_options"] = iop
338345
data["model_options"] = mop
346+
if "rewrite" in data:
347+
if rewrite:
348+
summary["model_rewrite"] = str(data["rewrite"])
349+
if verbose:
350+
print(f"[validate_model] model_rewrite={summary['model_rewrite']}")
351+
else:
352+
del data["rewrite"]
339353
if os.environ.get("PRINT_CONFIG", "0") in (1, "1"):
340354
print("[validate_model] -- PRINT CONFIG")
341355
print("-- type(config)", type(data["configuration"]))
@@ -446,6 +460,8 @@ def validate_model(
446460
patch_transformers=True,
447461
stop_if_static=stop_if_static,
448462
verbose=max(0, verbose - 1),
463+
rewrite=data.get("rewrite", None),
464+
dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
449465
) as modificator:
450466
data["inputs_export"] = modificator(data["inputs"]) # type: ignore
451467

0 commit comments

Comments
 (0)