diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 28ffe2b7..bae4cfe6 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.5.0 +++++ +* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test) + 0.4.4 +++++ diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py index 4e8fb90c..63e40591 100644 --- a/_unittests/ut_torch_export_patches/test_patch_module.py +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -349,6 +349,21 @@ def forward(self, x, y): self.assertEqualAny(expected, ep.module()(x, y)) self.assertEqualAny(expected_, ep.module()(-x, y)) + def test_rewrite_PLBartEncoderLayer(self): + from transformers.models.plbart.modeling_plbart import PLBartEncoderLayer + + rewritten = transform_method(PLBartEncoderLayer.forward, verbose=self.verbose) + self.assertIn( + ( + "torch.cond(hidden_states.dtype == torch.float16 and " + "(torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()), " + "branch_cond_then_1, branch_cond_else_1, [hidden_states])" + ), + rewritten.code, + ) + print() + print(rewritten.code) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/helpers/config_helper.py b/onnx_diagnostic/helpers/config_helper.py index 545a8af0..4b8ac43f 100644 --- a/onnx_diagnostic/helpers/config_helper.py +++ b/onnx_diagnostic/helpers/config_helper.py @@ -2,7 +2,7 @@ import importlib import inspect import re -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import transformers @@ -42,8 +42,16 @@ def update_config(config: Any, mkwargs: Dict[str, Any]): setattr(config, k, v) -def _pick(config, *atts): +def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None): """Returns the first value found in the configuration.""" + if ( + exceptions + and hasattr(config, "architectures") + and len(config.architectures) == 1 + and config.architectures[0] in exceptions + ): + excs = exceptions[config.architectures[0]] + return excs(config) for a in atts: if isinstance(a, str): if hasattr(config, a): diff --git a/onnx_diagnostic/tasks/text2text_generation.py b/onnx_diagnostic/tasks/text2text_generation.py index dde8841c..9e30d45a 100644 --- a/onnx_diagnostic/tasks/text2text_generation.py +++ b/onnx_diagnostic/tasks/text2text_generation.py @@ -164,6 +164,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: ("decoder_attention_heads", "encoder_attention_heads"), ), ) + # exceptions = { + # "PLBartForConditionalGeneration": ( + # lambda c: c.encoder_attention_heads + c.decoder_attention_heads + # ) + # } kwargs = dict( batch_size=2, sequence_length=30, @@ -181,6 +186,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: "num_key_value_heads", "num_heads", (sum, "encoder_attention_heads", "decoder_attention_heads"), + # exceptions=exceptions, ) ), encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"), diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index 4703e89a..0edadab3 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -1,4 +1,5 @@ import ast +import copy import inspect import types import textwrap @@ -341,11 +342,34 @@ def visit(self, node): return node +class _SelectiveAssignNormalizer(ast.NodeTransformer): + def visit_If(self, node): + self.generic_visit(node) + node.body = [self._transform_if_needed(stmt) for stmt in node.body] + node.orelse = [self._transform_if_needed(stmt) for stmt in node.orelse] + return node + + def _transform_if_needed(self, stmt): + if isinstance(stmt, ast.AugAssign): + return ast.Assign( + targets=[stmt.target], + value=ast.BinOp(left=copy.deepcopy(stmt.target), op=stmt.op, right=stmt.value), + ) + if isinstance(stmt, ast.AnnAssign) and stmt.value is not None: + return ast.Assign(targets=[stmt.target], value=stmt.value) + return self.visit(stmt) + + def inplace_add_parent(tree: "ast.Node"): """Adds parents to an AST tree.""" _AddParentTransformer().visit(tree) +def normalize_assignment_in_test(tree: "ast.Node"): + """Split AugAssign into BinOp and Assign to simplify whatever comes after.""" + _SelectiveAssignNormalizer().visit(tree) + + def transform_method( func: Callable, prefix: str = "branch_cond", @@ -451,6 +475,7 @@ def forward(self, x, y): skip_objects=modules, args_names=set(sig.parameters), ) + normalize_assignment_in_test(tree) inplace_add_parent(tree) new_tree = transformer.visit(tree) if verbose > 1: diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py index 9ae45d89..6cba901c 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data.py @@ -97,6 +97,7 @@ Phi3ForCausalLM,text-generation PhiForCausalLM,text-generation Pix2StructForConditionalGeneration,image-to-text + PLBartForConditionalGeneration,text2text-generation PoolFormerModel,image-feature-extraction PvtForImageClassification,image-classification Qwen2ForCausalLM,text-generation