Skip to content

Commit 93a8445

Browse files
committed
Improves rewriting
1 parent 54e8373 commit 93a8445

File tree

9 files changed

+161
-12
lines changed

9 files changed

+161
-12
lines changed

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ onnx_diagnostic.torch_export_patches
99
patch_expressions
1010
patch_inputs
1111
patch_module
12-
12+
patch_module_helper
1313

1414
.. automodule:: onnx_diagnostic.torch_export_patches
1515
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
onnx_diagnostic.torch_export_patches.patch_module
3+
=================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.patch_module
6+
:members:
7+
:no-undoc-members:
8+
:exclude-members: torch_export_rewrite

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ShapeFinder,
1414
RewriteControlFlow,
1515
)
16+
from onnx_diagnostic.torch_export_patches.patch_module_helper import ast_or_into_bitor
1617

1718

1819
class _ModelForATest(torch.nn.Module):
@@ -396,15 +397,25 @@ def forward(self, x, y):
396397
def test_rewrite_test_in_PLBartEncoderLayer(self):
397398
from transformers.models.plbart.modeling_plbart import PLBartEncoderLayer
398399

399-
rewritten = transform_method(PLBartEncoderLayer.forward, verbose=self.verbose)
400+
def filter_node(node) -> bool:
401+
return isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
402+
403+
rewritten = transform_method(
404+
PLBartEncoderLayer.forward,
405+
verbose=self.verbose,
406+
filter_node=filter_node,
407+
pre_rewriter=ast_or_into_bitor,
408+
)
400409
self.assertIn(
401410
(
402411
"torch.cond(hidden_states.dtype == torch.float16 and "
403-
"(torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()), "
412+
"torch.isinf(hidden_states).any()"
413+
" | torch.isnan(hidden_states).any(), "
404414
"branch_cond_then_1, branch_cond_else_1, [hidden_states])"
405415
),
406416
rewritten.code,
407417
)
418+
self.assertNotIn("torch.cond(output_attentions", rewritten.code)
408419

409420
@hide_stdout()
410421
def test_torch_export_patch_method_tuple(self):

_unittests/ut_torch_models/test_hghub_mode_rewrite.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,33 @@
11
import unittest
2-
from onnx_diagnostic.ext_test_case import ExtTestCase
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_errors
34
from onnx_diagnostic.torch_models.hghub.hub_data import code_needing_rewriting
5+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
6+
from onnx_diagnostic.torch_export_patches import torch_export_patches
7+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
48

59

610
class TestHuggingFaceHubModelRewrite(ExtTestCase):
711

812
def test_code_needing_rewriting(self):
9-
self.assertEqual(1, len(code_needing_rewriting("BartForConditionalGeneration")))
13+
self.assertEqual(2, len(code_needing_rewriting("BartForConditionalGeneration")))
14+
15+
@hide_stdout()
16+
@ignore_errors(OSError)
17+
def test_export_rewritin_bart(self):
18+
mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
19+
data = get_untrained_model_with_inputs(mid, verbose=1)
20+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
21+
dump_folder = self.get_dump_file("test_export_rewritin_bart")
22+
print(self.string_type(inputs))
23+
print(self.string_type(ds))
24+
with torch_export_patches(
25+
patch_transformers=True,
26+
rewrite=code_needing_rewriting("BartForConditionalGeneration"),
27+
dump_rewriting=dump_folder,
28+
):
29+
model(**inputs)
30+
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds))
1031

1132

1233
if __name__ == "__main__":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,12 @@ def get_parser_validate() -> ArgumentParser:
343343
help="drops the following inputs names, it should be a list "
344344
"with comma separated values",
345345
)
346+
parser.add_argument(
347+
"--opset",
348+
type=int,
349+
default=18,
350+
help="onnx opset to use, 18 by default",
351+
)
346352
parser.add_argument(
347353
"--subfolder",
348354
help="subfolder where to find the model and the configuration",
@@ -426,6 +432,7 @@ def _cmd_validate(argv: List[Any]):
426432
input_options=args.iop,
427433
model_options=args.mop,
428434
subfolder=args.subfolder,
435+
opset=args.opset,
429436
)
430437
print("")
431438
print("-- summary --")

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,20 @@ class RewriteControlFlow(ast.NodeTransformer):
8989
:param skip_objects: to skip variable names if included in that list
9090
such as modules
9191
:param args_names: defines the local variables
92+
:param filter_nodes: a function which is used to decide which node
93+
to rewrite, True by default
94+
:param pre_rewriter: a rewriter applied before the automated rewriting
95+
:param post_rewriter: a rewriter applied after the automated rewriting
9296
"""
9397

9498
def __init__(
9599
self,
96100
prefix: str = "branch_cond",
97101
skip_objects: Optional[Dict[str, object]] = None,
98102
args_names: Optional[Set[str]] = None,
103+
filter_node: Optional[Callable["ast.Node", bool]] = None,
104+
pre_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
105+
post_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
99106
):
100107
self.counter_test = 0
101108
self.counter_loop = 0
@@ -104,6 +111,9 @@ def __init__(
104111
self.skip_objects = skip_objects or {}
105112
self.args_names = args_names or set()
106113
self.local_variables = self.args_names.copy()
114+
self.filter_node = filter_node or (lambda _node: True)
115+
self.pre_rewriter = pre_rewriter or (lambda node: node)
116+
self.post_rewriter = post_rewriter or (lambda node: node)
107117

108118
def generic_visit(self, node):
109119
return super().generic_visit(node)
@@ -320,6 +330,11 @@ def _make_targets(self, node, then_assigns, else_assigns):
320330
return tgt, tgt_mapping
321331

322332
def visit_If(self, node):
333+
if not self.filter_node(node):
334+
return [node]
335+
336+
node = self.pre_rewriter(node)
337+
323338
# First recurse into subnodes
324339
known_local_variables = self.local_variables.copy()
325340
node = self.generic_visit(node)
@@ -380,7 +395,7 @@ def visit_If(self, node):
380395
ast.copy_location(assign, node)
381396
ast.fix_missing_locations(assign)
382397
self.local_variables = known_local_variables | added
383-
return [then_def, else_def, assign]
398+
return [self.post_rewriter(n) for n in [then_def, else_def, assign]]
384399

385400
# Case 2: return in both branches, we assume both branches return the same results.
386401
then_ret = node.body[-1]
@@ -403,7 +418,7 @@ def visit_If(self, node):
403418
ret = ast.Return(call)
404419
ast.copy_location(ret, node)
405420
ast.fix_missing_locations(ret)
406-
return [then_def, else_def, ret]
421+
return [self.post_rewriter(n) for n in [then_def, else_def, ret]]
407422

408423
def _find_loop_vars(self, node):
409424
assert isinstance(node, ast.For), f"Unexpected type {type(node)} for node"
@@ -462,6 +477,11 @@ def _find_loop_vars(self, node):
462477
)
463478

464479
def visit_For(self, node):
480+
if not self.filter_node(node):
481+
return [node]
482+
483+
node = self.pre_rewriter(node)
484+
465485
# For nested loops.
466486
self.generic_visit(node)
467487
# look for variables, loop, inputs and outputs of the body
@@ -622,7 +642,7 @@ def visit_For(self, node):
622642
ctx=ast.Store(),
623643
)
624644
assign = ast.Assign(targets=[target], value=call)
625-
return [func_def, assign]
645+
return [self.post_rewriter(func_def), self.post_rewriter(assign)]
626646

627647

628648
class RewrittenMethod:
@@ -697,6 +717,9 @@ def transform_method(
697717
func: Callable,
698718
prefix: str = "branch_cond",
699719
verbose: int = 0,
720+
filter_node: Optional[Callable["ast.Node", bool]] = None,
721+
pre_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
722+
post_rewriter: Optional[Callable["ast.Node", "ast.Node"]] = None,
700723
) -> RewrittenMethod:
701724
"""
702725
Returns a new function based on `func` where every test (if)
@@ -717,6 +740,9 @@ def transform_method(
717740
:param func: method or function to rewrite
718741
:param prefix: prefix used to create the functions for the branches
719742
:param verbose: verbosity
743+
:param filter_node: a function which tells which node to rewrite
744+
:param pre_rewriter: a rewriter applied before the automated rewriting
745+
:param post_rewriter: a rewriter applied after the automated rewriting
720746
:return: rewritten method
721747
722748
An example with **return**:
@@ -801,6 +827,9 @@ def forward(self, x, y):
801827
prefix=prefix,
802828
skip_objects=modules,
803829
args_names=set(sig.parameters),
830+
filter_node=filter_node,
831+
pre_rewriter=pre_rewriter,
832+
post_rewriter=post_rewriter,
804833
)
805834
normalize_assignment_in_test(tree)
806835
inplace_add_parent(tree)
@@ -912,7 +941,22 @@ def forward(self, x, y):
912941
cls, name = me
913942
to_rewrite = getattr(cls, name)
914943
kind = "method"
944+
kws = {}
915945
else:
946+
if isinstance(me, dict):
947+
assert "function" in me and (
948+
"filter_node" in me or "pre_rewriter" in me or "post_rewriter" in me
949+
), (
950+
f"If the rewriting code is defined as a dictionary, key "
951+
f"'function' must be defined, other arguments must be understood by "
952+
f"{transform_method.__name__}, "
953+
f"the given value is {me!r}."
954+
)
955+
kws = me
956+
me = me["function"]
957+
del kws["function"]
958+
else:
959+
kws = {}
916960
name = me.__qualname__
917961
spl = name.split(".")
918962
if len(spl) == 1:
@@ -958,8 +1002,9 @@ def forward(self, x, y):
9581002
if verbose:
9591003
print(f"[torch_export_rewrite] dump original code in {filename!r}")
9601004
with open(filename, "w") as f:
961-
f.write(inspect.getsource(to_rewrite))
962-
rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0))
1005+
code = inspect.getsource(to_rewrite)
1006+
f.write(code)
1007+
rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0), **kws)
9631008
if dump_rewriting:
9641009
filename = f"{dump_rewriting}.{kind}.{cls_name}.{name}.rewritten.py"
9651010
if verbose:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import ast
2+
3+
4+
class OrToBitOrTransformer(ast.NodeTransformer):
5+
def visit_BoolOp(self, node):
6+
self.generic_visit(node)
7+
if isinstance(node.op, ast.Or):
8+
new_node = node.values[0]
9+
for value in node.values[1:]:
10+
new_node = ast.BinOp(left=new_node, op=ast.BitOr(), right=value)
11+
return ast.copy_location(new_node, node)
12+
return node
13+
14+
15+
def ast_or_into_bitor(node: "ast.Node") -> "ast.Node":
16+
"""Replaces every operator ``or`` into ``|``."""
17+
new_node = OrToBitOrTransformer().visit(node)
18+
return new_node

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import io
23
import functools
34
import textwrap
@@ -214,10 +215,31 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
214215
215216
print(code_needing_rewriting("BartForConditionalGeneration"))
216217
"""
217-
if cls_name in {"BartForConditionalGeneration", "BartEncoderLayer"}:
218+
if cls_name in {
219+
"BartEncoderLayer",
220+
"BartForConditionalGeneration",
221+
"PLBartEncoderLayer",
222+
"PLBartForConditionalGeneration",
223+
}:
218224
import transformers
225+
from ...torch_export_patches.patch_module_helper import ast_or_into_bitor
219226

220-
return [transformers.models.bart.modeling_bart.BartEncoderLayer.forward]
227+
bd = dict(
228+
filter_node=(
229+
lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
230+
),
231+
pre_rewriter=ast_or_into_bitor,
232+
)
233+
234+
def _add(f):
235+
g = bd.copy()
236+
g["function"] = f
237+
return g
238+
239+
return [
240+
_add(transformers.models.bart.modeling_bart.BartEncoderLayer.forward),
241+
_add(transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward),
242+
]
221243
return None
222244

223245

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def validate_model(
230230
input_options: Optional[Dict[str, Any]] = None,
231231
model_options: Optional[Dict[str, Any]] = None,
232232
subfolder: Optional[str] = None,
233+
opset: Optional[int] = None,
233234
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
234235
"""
235236
Validates a model.
@@ -265,6 +266,7 @@ def validate_model(
265266
:param model_options: additional options when creating the model such as
266267
``num_hidden_layers`` or ``attn_implementation``
267268
:param subfolder: version or subfolders to uses when retrieving a model id
269+
:param opset: onnx opset to use for the conversion
268270
:return: two dictionaries, one with some metrics,
269271
another one with whatever the function produces
270272
@@ -295,6 +297,8 @@ def validate_model(
295297
version_exporter=exporter or "",
296298
)
297299
)
300+
if opset:
301+
summary["version_opset"] = opset
298302

299303
folder_name = None
300304
if dump_folder:
@@ -343,6 +347,8 @@ def validate_model(
343347
)
344348
data["input_options"] = iop
345349
data["model_options"] = mop
350+
if opset:
351+
data["model_opset"] = opset
346352
if "rewrite" in data:
347353
if rewrite:
348354
summary["model_rewrite"] = str(data["rewrite"])
@@ -992,6 +998,9 @@ def call_torch_export_onnx(
992998
summary["export_dynamo"] = dynamo
993999
summary["export_args"] = string_type(args, with_shape=True)
9941000
summary["export_kwargs"] = string_type(kwargs, with_shape=True)
1001+
opset = data.get("model_opset", None)
1002+
if opset:
1003+
summary["export_opset"] = opset
9951004

9961005
if dynamo:
9971006
export_export_kwargs = dict(dynamo=True, dynamic_shapes=ds)
@@ -1012,6 +1021,8 @@ def call_torch_export_onnx(
10121021
print("[call_torch_export_onnx] dynamo=False so...")
10131022
print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
10141023
print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
1024+
if opset:
1025+
export_export_kwargs["opset_version"] = opset
10151026
if verbose:
10161027
print(
10171028
f"[call_torch_export_onnx] export_export_kwargs="
@@ -1123,6 +1134,9 @@ def call_torch_export_custom(
11231134
strict = "-strict" in exporter
11241135
args, kwargs = split_args_kwargs(data["inputs_export"])
11251136
ds = data.get("dynamic_shapes", None)
1137+
opset = data.get("model_opset", None)
1138+
if opset:
1139+
summary["export_opset"] = opset
11261140
if verbose:
11271141
print(
11281142
f"[call_torch_export_custom] exporter={exporter!r}, "
@@ -1163,6 +1177,9 @@ def call_torch_export_custom(
11631177
return_optimize_report=True,
11641178
verbose=max(verbose - 2, 0),
11651179
)
1180+
if opset:
1181+
kws["target_opset"] = opset
1182+
assert opset
11661183

11671184
epo, opt_stats = _quiet_or_not_quiet(
11681185
quiet,

0 commit comments

Comments
 (0)