Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.5.0
+++++

* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)

0.4.4
+++++

Expand Down
15 changes: 15 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,21 @@ def forward(self, x, y):
self.assertEqualAny(expected, ep.module()(x, y))
self.assertEqualAny(expected_, ep.module()(-x, y))

def test_rewrite_PLBartEncoderLayer(self):
from transformers.models.plbart.modeling_plbart import PLBartEncoderLayer

rewritten = transform_method(PLBartEncoderLayer.forward, verbose=self.verbose)
self.assertIn(
(
"torch.cond(hidden_states.dtype == torch.float16 and "
"(torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()), "
"branch_cond_then_1, branch_cond_else_1, [hidden_states])"
),
rewritten.code,
)
print()
print(rewritten.code)


if __name__ == "__main__":
unittest.main(verbosity=2)
12 changes: 10 additions & 2 deletions onnx_diagnostic/helpers/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import importlib
import inspect
import re
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import transformers


Expand Down Expand Up @@ -42,8 +42,16 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
setattr(config, k, v)


def _pick(config, *atts):
def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
"""Returns the first value found in the configuration."""
if (
exceptions
and hasattr(config, "architectures")
and len(config.architectures) == 1
and config.architectures[0] in exceptions
):
excs = exceptions[config.architectures[0]]
return excs(config)
for a in atts:
if isinstance(a, str):
if hasattr(config, a):
Expand Down
6 changes: 6 additions & 0 deletions onnx_diagnostic/tasks/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
("decoder_attention_heads", "encoder_attention_heads"),
),
)
# exceptions = {
# "PLBartForConditionalGeneration": (
# lambda c: c.encoder_attention_heads + c.decoder_attention_heads
# )
# }
kwargs = dict(
batch_size=2,
sequence_length=30,
Expand All @@ -181,6 +186,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"num_key_value_heads",
"num_heads",
(sum, "encoder_attention_heads", "decoder_attention_heads"),
# exceptions=exceptions,
)
),
encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
Expand Down
25 changes: 25 additions & 0 deletions onnx_diagnostic/torch_export_patches/patch_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
import inspect
import types
import textwrap
Expand Down Expand Up @@ -341,11 +342,34 @@ def visit(self, node):
return node


class _SelectiveAssignNormalizer(ast.NodeTransformer):
def visit_If(self, node):
self.generic_visit(node)
node.body = [self._transform_if_needed(stmt) for stmt in node.body]
node.orelse = [self._transform_if_needed(stmt) for stmt in node.orelse]
return node

def _transform_if_needed(self, stmt):
if isinstance(stmt, ast.AugAssign):
return ast.Assign(
targets=[stmt.target],
value=ast.BinOp(left=copy.deepcopy(stmt.target), op=stmt.op, right=stmt.value),
)
if isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
return ast.Assign(targets=[stmt.target], value=stmt.value)
return self.visit(stmt)


def inplace_add_parent(tree: "ast.Node"):
"""Adds parents to an AST tree."""
_AddParentTransformer().visit(tree)


def normalize_assignment_in_test(tree: "ast.Node"):
"""Split AugAssign into BinOp and Assign to simplify whatever comes after."""
_SelectiveAssignNormalizer().visit(tree)


def transform_method(
func: Callable,
prefix: str = "branch_cond",
Expand Down Expand Up @@ -451,6 +475,7 @@ def forward(self, x, y):
skip_objects=modules,
args_names=set(sig.parameters),
)
normalize_assignment_in_test(tree)
inplace_add_parent(tree)
new_tree = transformer.visit(tree)
if verbose > 1:
Expand Down
1 change: 1 addition & 0 deletions onnx_diagnostic/torch_models/hghub/hub_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
Phi3ForCausalLM,text-generation
PhiForCausalLM,text-generation
Pix2StructForConditionalGeneration,image-to-text
PLBartForConditionalGeneration,text2text-generation
PoolFormerModel,image-feature-extraction
PvtForImageClassification,image-classification
Qwen2ForCausalLM,text-generation
Expand Down
Loading