Skip to content

Commit 4d22b9d

Browse files
authored
Split AugAssign (#85)
* Split AugAssign * fix PR number * fix ut
1 parent b2ab9b7 commit 4d22b9d

File tree

6 files changed

+59
-2
lines changed

6 files changed

+59
-2
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
8+
79
0.4.4
810
+++++
911

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,21 @@ def forward(self, x, y):
349349
self.assertEqualAny(expected, ep.module()(x, y))
350350
self.assertEqualAny(expected_, ep.module()(-x, y))
351351

352+
def test_rewrite_PLBartEncoderLayer(self):
353+
from transformers.models.plbart.modeling_plbart import PLBartEncoderLayer
354+
355+
rewritten = transform_method(PLBartEncoderLayer.forward, verbose=self.verbose)
356+
self.assertIn(
357+
(
358+
"torch.cond(hidden_states.dtype == torch.float16 and "
359+
"(torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()), "
360+
"branch_cond_then_1, branch_cond_else_1, [hidden_states])"
361+
),
362+
rewritten.code,
363+
)
364+
print()
365+
print(rewritten.code)
366+
352367

353368
if __name__ == "__main__":
354369
unittest.main(verbosity=2)

onnx_diagnostic/helpers/config_helper.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import importlib
33
import inspect
44
import re
5-
from typing import Any, Dict, Optional, Tuple, Union
5+
from typing import Any, Callable, Dict, Optional, Tuple, Union
66
import transformers
77

88

@@ -42,8 +42,16 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
4242
setattr(config, k, v)
4343

4444

45-
def _pick(config, *atts):
45+
def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
4646
"""Returns the first value found in the configuration."""
47+
if (
48+
exceptions
49+
and hasattr(config, "architectures")
50+
and len(config.architectures) == 1
51+
and config.architectures[0] in exceptions
52+
):
53+
excs = exceptions[config.architectures[0]]
54+
return excs(config)
4755
for a in atts:
4856
if isinstance(a, str):
4957
if hasattr(config, a):

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
164164
("decoder_attention_heads", "encoder_attention_heads"),
165165
),
166166
)
167+
# exceptions = {
168+
# "PLBartForConditionalGeneration": (
169+
# lambda c: c.encoder_attention_heads + c.decoder_attention_heads
170+
# )
171+
# }
167172
kwargs = dict(
168173
batch_size=2,
169174
sequence_length=30,
@@ -181,6 +186,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
181186
"num_key_value_heads",
182187
"num_heads",
183188
(sum, "encoder_attention_heads", "decoder_attention_heads"),
189+
# exceptions=exceptions,
184190
)
185191
),
186192
encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import copy
23
import inspect
34
import types
45
import textwrap
@@ -341,11 +342,34 @@ def visit(self, node):
341342
return node
342343

343344

345+
class _SelectiveAssignNormalizer(ast.NodeTransformer):
346+
def visit_If(self, node):
347+
self.generic_visit(node)
348+
node.body = [self._transform_if_needed(stmt) for stmt in node.body]
349+
node.orelse = [self._transform_if_needed(stmt) for stmt in node.orelse]
350+
return node
351+
352+
def _transform_if_needed(self, stmt):
353+
if isinstance(stmt, ast.AugAssign):
354+
return ast.Assign(
355+
targets=[stmt.target],
356+
value=ast.BinOp(left=copy.deepcopy(stmt.target), op=stmt.op, right=stmt.value),
357+
)
358+
if isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
359+
return ast.Assign(targets=[stmt.target], value=stmt.value)
360+
return self.visit(stmt)
361+
362+
344363
def inplace_add_parent(tree: "ast.Node"):
345364
"""Adds parents to an AST tree."""
346365
_AddParentTransformer().visit(tree)
347366

348367

368+
def normalize_assignment_in_test(tree: "ast.Node"):
369+
"""Split AugAssign into BinOp and Assign to simplify whatever comes after."""
370+
_SelectiveAssignNormalizer().visit(tree)
371+
372+
349373
def transform_method(
350374
func: Callable,
351375
prefix: str = "branch_cond",
@@ -451,6 +475,7 @@ def forward(self, x, y):
451475
skip_objects=modules,
452476
args_names=set(sig.parameters),
453477
)
478+
normalize_assignment_in_test(tree)
454479
inplace_add_parent(tree)
455480
new_tree = transformer.visit(tree)
456481
if verbose > 1:

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
Phi3ForCausalLM,text-generation
9898
PhiForCausalLM,text-generation
9999
Pix2StructForConditionalGeneration,image-to-text
100+
PLBartForConditionalGeneration,text2text-generation
100101
PoolFormerModel,image-feature-extraction
101102
PvtForImageClassification,image-classification
102103
Qwen2ForCausalLM,text-generation

0 commit comments

Comments
 (0)