From c7e3ea1bafe0bf5677d82076cd30ac368a584e39 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 30 Sep 2025 13:57:11 +0200 Subject: [PATCH 1/5] Add a patch for dimension in 0/1 --- .../test_patch_rewrite.py | 13 -- .../test_patch_rewriting.py | 5 + .../test_patch_torch.py | 38 ++++++ .../onnx_export_errors.py | 12 ++ .../patches/patch_torch.py | 126 ++++++++++++++++-- 5 files changed, 171 insertions(+), 23 deletions(-) delete mode 100644 _unittests/ut_torch_export_patches/test_patch_rewrite.py diff --git a/_unittests/ut_torch_export_patches/test_patch_rewrite.py b/_unittests/ut_torch_export_patches/test_patch_rewrite.py deleted file mode 100644 index cd0cbd56..00000000 --- a/_unittests/ut_torch_export_patches/test_patch_rewrite.py +++ /dev/null @@ -1,13 +0,0 @@ -import unittest -from onnx_diagnostic.ext_test_case import ExtTestCase -from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting - - -class TestPatchRewrite(ExtTestCase): - def test_code_needing_rewriting(self): - res = code_needing_rewriting("BartModel") - self.assertEqual(len(res), 2) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_rewriting.py b/_unittests/ut_torch_export_patches/test_patch_rewriting.py index 7c8a0f8e..ea01d16e 100644 --- a/_unittests/ut_torch_export_patches/test_patch_rewriting.py +++ b/_unittests/ut_torch_export_patches/test_patch_rewriting.py @@ -3,6 +3,7 @@ from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( rewrite_loop_for_square_mask, ) +from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting class TestPatchRewriting(ExtTestCase): @@ -33,6 +34,10 @@ def apply_mask(mask, seq): m2 = rewrite_loop_for_square_mask(mask, seq) self.assertEqualArray(m1, m2) + def test_code_needing_rewriting(self): + res = code_needing_rewriting("BartModel") + self.assertEqual(len(res), 2) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 25eb5a73..3c266bfe 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -3,6 +3,8 @@ import torch from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers +from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str class TestPatchPatchTorch(ExtTestCase): @@ -236,6 +238,42 @@ def forward(self, x): ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},)) self.assertEqualArray(Model()(x), ep.module()(x)) + def test_oblivious_for_dimension_01(self): + class Model(torch.nn.Module): + def forward(self, x, ind1, ind2): + return x[ind1, ind2] + + inputs = ( + torch.randn(2, 1024), + torch.tensor([[0, 1]], dtype=torch.int64).T, + torch.arange(1024, dtype=torch.int64), + ) + model = Model() + expected = model(*inputs) + + dynamic_string = ({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"}) + dynamic_shapes = use_dyn_not_str(dynamic_string) + with self.subTest(name="export 0/1 specialized due to hint of 1 for dimension"): + try: + torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + raise AssertionError("torch fixed that case") + except ValueError as e: + self.assertIn("export 0/1 specialized due to hint 
of 1 for dimension", str(e)) + + with self.subTest(name="expected shape should be broadcastable to"): + try: + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + raise AssertionError("torch fixed that case") + except RuntimeError as e: + self.assertIn("expected shape should be broadcastable to", str(e)) + + with self.subTest(name="patch for 0/1"): + with torch_export_patches(): + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + got = ep.module()(*inputs) + self.assertEqualArray(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index f366bc80..210e98e8 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -341,6 +341,7 @@ def torch_export_patches( patched_infer_size, patched_vmap, patched__broadcast_shapes, + patched__constrain_user_specified_dimhint_range, _catch_produce_guards_and_solve_constraints, patch__check_input_constraints_for_graph, ) @@ -371,6 +372,14 @@ def torch_export_patches( torch._refs._broadcast_shapes = patched__broadcast_shapes torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes + # torch._export.non_strict_utils._constrain_user_specified_dimhint_range + f___constrain_user_specified_dimhint_range = ( + torch._export.non_strict_utils._constrain_user_specified_dimhint_range + ) + torch._export.non_strict_utils._constrain_user_specified_dimhint_range = ( + patched__constrain_user_specified_dimhint_range + ) + # torch._export.non_strict_utils.produce_guards_and_solve_constraints if patch_torch and catch_constraints: if verbose: @@ -569,6 +578,9 @@ def torch_export_patches( torch._subclasses.fake_impls.infer_size = f_infer_size torch._refs._broadcast_shapes = f__broadcast_shapes torch._meta_registrations._broadcast_shapes = f__broadcast_shapes + torch._export.non_strict_utils._constrain_user_specified_dimhint_range = ( + f___constrain_user_specified_dimhint_range + ) if verbose: print("[torch_export_patches] restored pytorch functions") diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 29f328da..3fb452f9 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -1,7 +1,7 @@ import inspect import os import traceback -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from torch._subclasses.fake_tensor import FakeTensorMode @@ -65,6 +65,8 @@ def patch__check_input_constraints_for_graph( verbose: int = 0, ) -> None: try: + # PATCHED: catches exception and prints out the information instead of + # stopping the conversion. return previous_function(input_placeholders, flat_args_with_path, range_constraints) except Exception as e: if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")): @@ -122,8 +124,7 @@ def patched_infer_size(a, b): if b1 or b2 or b3: expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA else: - # In this case, the current implementation of torch fails (17/12/2024). - # Try model SmolLM. 
+ # PATCHED: generic case, the dimension is known, no need to assert expandedSizes[i] = torch.sym_max(sizeA, sizeB) return tuple(expandedSizes) @@ -132,7 +133,11 @@ def patched__broadcast_shapes(*_shapes): """Patches ``torch._refs._broadcast_shapes``.""" from functools import reduce from torch._prims_common import IntLike - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import ( + guard_size_oblivious, + guard_or_false, + is_nested_int, + ) shapes = tuple( (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes) @@ -142,17 +147,30 @@ def patched__broadcast_shapes(*_shapes): if len(shapes) == 0: return None - # Type checking - # TODO: make common validations available as utils for shape in shapes: - assert isinstance(shape, Sequence) + if not isinstance(shape, Sequence): + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, " + "or a list of ints, got ", + shape, + ) # Computes common shape - common_shape = [ # List[Union[int, torch.SymInt]] - 1, - ] * reduce(max, (len(shape) for shape in shapes)) + common_shape = [1] * reduce(max, (len(shape) for shape in shapes)) for _arg_idx, shape in enumerate(shapes): for idx in range(-1, -1 - len(shape), -1): + if is_nested_int(shape[idx]): + # Broadcasting is allowed for (j0, 1) or (j0, j0); + # not (j0, j1), (j0, 5), etc. + if is_nested_int(common_shape[idx]) and guard_or_false( + shape[idx] == common_shape[idx] + ): + continue + else: + if guard_or_false(shape[idx] == common_shape[idx]): + continue + # PATCHED: two cases, if == for sure, no broadcast, + # otherwise maybe broadcase with max(dimensions) if guard_size_oblivious(common_shape[idx] == 1): if shape[idx] < 0: raise ValueError( @@ -172,6 +190,7 @@ def _check_frozen( ) -> None: if self.frozen: self.counter["ignored_backward_guard"] += 1 + # PATCHED: raised an exception instead of logging. raise AssertionError( f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, " f"this could result in accuracy problems" @@ -338,11 +357,13 @@ def _set_replacement( }, ) + # PATCHED: removed lines # if config.print_specializations: # self.log.warning( # "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt # ) # self.log.debug("SPECIALIZATION", stack_info=True) + # PATCHED: replaces logging by raising an exception assert msg != "range_refined_to_singleton", ( f"patched_ShapeEnv: A dynamic dimension becomes static! " f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}" @@ -364,6 +385,7 @@ def _log_guard( self, prefix: str, g: "SympyBoolean", forcing_spec: bool # noqa: F821 ) -> None: self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec) + # PATCHED: removed # It happens too often to be relevant. 
# sloc, _maybe_extra_debug = self._get_stack_summary(True) # warnings.warn( @@ -464,3 +486,87 @@ def wrapped(*args): return results return wrapped + + +def patched__constrain_user_specified_dimhint_range( + symint: torch.SymInt, + hint: int, + dim: "_DimHint", # noqa: F821 + range_constraints, + shape_env, + keypath: "KeyPath", # noqa: F821 + i: Optional[int] = None, +) -> Optional[str]: + """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``.""" + from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges + + trace_vr = ( + range_constraints[symint.node.expr] + if not is_int(symint) + else ValueRanges(int(symint), int(symint)) + ) + # warn on 0/1 specialization for Dim.AUTO; not an actual error + # PATCHED: remove logging + # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1): + # pathstr = f"inputs{pytree.keystr(keypath)}" + # if i is not None: + # pathstr += f".shape[{i}]" + # msg = ( + # f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along " + # f"with a sample input with hint = {hint}." + # ) + # log.warning(msg) + + try: + user_vr = ValueRanges( + lower=0 if dim.min is None else dim.min, + upper=int_oo if dim.max is None else dim.max, + ) + if is_int(symint): + out_vr = trace_vr & user_vr + else: + range_constraints[symint.node.expr] &= user_vr + shape_env.var_to_range[symint.node._expr] &= user_vr + out_vr = range_constraints[symint.node.expr] + + # check for Dim.DYNAMIC specializations; special case error message on 0/1 + if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton(): + path = f"inputs{torch.utils._pytree.keystr(keypath)}" + if i is not None: + path += f".shape[{i}]" + if ( + trace_vr.is_singleton() + and hint in (0, 1) + # PATCHED: line removed + # and not torch.fx.experimental._config.backed_size_oblivious + ): + return None + # PATCHED: line removed + # msg = ( + # f"- Received user-specified dim hint " + # f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), " + # f"but export 0/1 specialized due to hint of " + # f"{hint} for dimension {path}." + # ) + else: + msg = ( + f"- Received user-specified dim hint " + f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), " + f"but tracing inferred a static shape of " + f"{out_vr.lower} for dimension {path}." + ) + return msg + + except torch.utils._sympy.value_ranges.ValueRangeError: + path = f"inputs{torch.utils._pytree.keystr(keypath)}" + if i is not None: + path += f".shape[{i}]" + msg = ( + f"- Received user-specified min/max range of [{dim.min}, {dim.max}], " + f"conflicting with the inferred min/max range of " + f"[{trace_vr.lower}, {trace_vr.upper}], " + f"for {path}." 
+ ) + return msg + + return None From bccdc0cc64cf207b197b493683cbaaeae76d558e Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 30 Sep 2025 14:07:32 +0200 Subject: [PATCH 2/5] doc --- CHANGELOGS.rst | 2 ++ onnx_diagnostic/torch_export_patches/patches/patch_torch.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 569ef18d..dd498636 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.7.13 ++++++ +* :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1} + 0.7.12 ++++++ diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 3fb452f9..6e24835b 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -170,7 +170,7 @@ def patched__broadcast_shapes(*_shapes): if guard_or_false(shape[idx] == common_shape[idx]): continue # PATCHED: two cases, if == for sure, no broadcast, - # otherwise maybe broadcase with max(dimensions) + # otherwise maybe broadcast with max(dimensions) if guard_size_oblivious(common_shape[idx] == 1): if shape[idx] < 0: raise ValueError( From 15ab96e7e2374de665c36c8d36e5a9a23983c11a Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 30 Sep 2025 14:33:20 +0200 Subject: [PATCH 3/5] auto --- .../test_patch_torch.py | 56 ++++++++++++++++--- .../torch_export_patches/patch_inputs.py | 16 ++++-- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 3c266bfe..18c3ad6a 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -2,7 +2,12 @@ from typing import Callable import torch from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex -from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + requires_torch, + requires_transformers, + has_torch, +) from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str @@ -252,28 +257,61 @@ def forward(self, x, ind1, ind2): expected = model(*inputs) dynamic_string = ({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"}) + # ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN}) + dynamic_shapes = use_dyn_not_str(dynamic_string) - with self.subTest(name="export 0/1 specialized due to hint of 1 for dimension"): + with self.subTest( + name="export 0/1 specialized due to hint of 1 for dimension", + dynamic_shapes=dynamic_shapes, + ): try: torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) raise AssertionError("torch fixed that case") except ValueError as e: self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e)) - with self.subTest(name="expected shape should be broadcastable to"): - try: + dynamic_shapes = use_dyn_not_str(dynamic_string, torch.export.Dim.AUTO) + if has_torch("2.9"): + with self.subTest( + name="expected shape should be broadcastable to (>= 2.9)", + dynamic_shapes=dynamic_shapes, + ): + try: + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + raise AssertionError("torch fixed that case") + except RuntimeError as e: + 
self.assertIn("expected shape should be broadcastable to", str(e)) + + if not has_torch("2.9"): + with self.subTest( + name="expected shape should be broadcastable to (< 2.9)", + dynamic_shapes=dynamic_shapes, + ): with torch.fx.experimental._config.patch(backed_size_oblivious=True): - torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) - raise AssertionError("torch fixed that case") - except RuntimeError as e: - self.assertIn("expected shape should be broadcastable to", str(e)) + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + got = ep.module()(*inputs) + self.assertEqualArray(expected, got) - with self.subTest(name="patch for 0/1"): + with self.subTest(name="patch for 0/1", dynamic_shapes=dynamic_shapes): with torch_export_patches(): ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) got = ep.module()(*inputs) self.assertEqualArray(expected, got) + if has_torch("2.11"): + # Missing PR https://github.com/pytorch/pytorch/pull/164225 + # Needs more thinking about the patch to apply for this particular example. + with self.subTest( + name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes + ): + with torch_export_patches(), torch.fx.experimental._config.patch( + backed_size_oblivious=True + ): + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + got = ep.module()(*inputs) + self.assertEqualArray(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/patch_inputs.py b/onnx_diagnostic/torch_export_patches/patch_inputs.py index 86baad12..3c93a70d 100644 --- a/onnx_diagnostic/torch_export_patches/patch_inputs.py +++ b/onnx_diagnostic/torch_export_patches/patch_inputs.py @@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes( return (), updated_kwargs, dynamic_shapes -def use_dyn_not_str(dynamic_shapes: Any) -> Any: +def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any: """ Some functions returns dynamic shapes as string. This functions replaces them with ``torch.export.Dim.DYNAMIC``. + ``default_value=torch.export.Dim.AUTO`` changes the default value. 
""" if isinstance(dynamic_shapes, list): - return [use_dyn_not_str(a) for a in dynamic_shapes] + return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes] if isinstance(dynamic_shapes, tuple): - return tuple(use_dyn_not_str(a) for a in dynamic_shapes) + return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes) if isinstance(dynamic_shapes, dict): - return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()} + return { + k: use_dyn_not_str(v, default_value=default_value) + for k, v in dynamic_shapes.items() + } if isinstance(dynamic_shapes, set): - return {use_dyn_not_str(a) for a in dynamic_shapes} + return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes} if isinstance(dynamic_shapes, str): - return torch.export.Dim.DYNAMIC + return torch.export.Dim.DYNAMIC if default_value is None else default_value return dynamic_shapes From 40cadf5d54921f790ae321fcb8b7c6f155ae7ebb Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 30 Sep 2025 15:16:10 +0200 Subject: [PATCH 4/5] fix issues --- _doc/conf.py | 2 ++ .../ut_torch_export_patches/test_patch_torch.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/_doc/conf.py b/_doc/conf.py index 766d4c02..78361a79 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -114,6 +114,8 @@ def linkcode_resolve(domain, info): nitpicky = True # See also scikit-learn/scikit-learn#26761 nitpick_ignore = [ + ("py:class", "_DimHint"), + ("py:class", "KeyPath"), ("py:class", "ast.Node"), ("py:class", "dtype"), ("py:class", "False"), diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 18c3ad6a..a5cf9663 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -288,10 +288,14 @@ def forward(self, x, ind1, ind2): name="expected shape should be broadcastable to (< 2.9)", dynamic_shapes=dynamic_shapes, ): - with torch.fx.experimental._config.patch(backed_size_oblivious=True): - ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) - got = ep.module()(*inputs) - self.assertEqualArray(expected, got) + try: + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + except RuntimeError as e: + self.assertIn( + "Expected input at *args[2].shape[0] to be equal to 1, but got 1024", + str(e), + ) with self.subTest(name="patch for 0/1", dynamic_shapes=dynamic_shapes): with torch_export_patches(): From 46b8402d7249e72059719d985688f26df877af0c Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 30 Sep 2025 16:26:51 +0200 Subject: [PATCH 5/5] fix unit test --- .../ut_tasks/test_tasks_image_to_video.py | 51 ++++++++++++++++++- _unittests/ut_torch_models/test_llm_phi2.py | 43 ++++++++++++++-- 2 files changed, 89 insertions(+), 5 deletions(-) diff --git a/_unittests/ut_tasks/test_tasks_image_to_video.py b/_unittests/ut_tasks/test_tasks_image_to_video.py index dc40697f..7cd10bab 100644 --- a/_unittests/ut_tasks/test_tasks_image_to_video.py +++ b/_unittests/ut_tasks/test_tasks_image_to_video.py @@ -17,8 +17,8 @@ class TestTasksImageToVideo(ExtTestCase): @hide_stdout() @requires_diffusers("0.35") @requires_transformers("4.55") - @requires_torch("2.8.99") - def test_image_to_video(self): + @requires_torch("2.10.99") + def test_image_to_video_oblivious(self): kwargs = { "_diffusers_version": "0.34.0.dev0", "_class_name": 
"CosmosTransformer3DModel", @@ -63,6 +63,53 @@ def test_image_to_video(self): model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) + @hide_stdout() + @requires_diffusers("0.35") + @requires_transformers("4.55") + @requires_torch("2.8.99") + def test_image_to_video_not_oblivious(self): + kwargs = { + "_diffusers_version": "0.34.0.dev0", + "_class_name": "CosmosTransformer3DModel", + "max_size": [128, 240, 240], + "text_embed_dim": 128, + "use_cache": True, + "in_channels": 3, + "out_channels": 16, + "num_layers": 2, + "model_type": "dia", + "patch_size": [1, 2, 2], + "rope_scale": [1.0, 3.0, 3.0], + "attention_head_dim": 16, + "mlp_ratio": 0.4, + "initializer_range": 0.02, + "num_attention_heads": 16, + "is_encoder_decoder": True, + "adaln_lora_dim": 16, + "concat_padding_mask": True, + "extra_pos_embed_type": None, + } + config = transformers.DiaConfig(**kwargs) + mid = "nvidia/Cosmos-Predict2-2B-Video2World" + data = get_untrained_model_with_inputs( + mid, + verbose=1, + add_second_input=True, + subfolder="transformer", + config=config, + inputs_kwargs=dict(image_height=8 * 50, image_width=8 * 80), + ) + self.assertEqual(data["task"], "image-to-video") + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + model(**inputs) + model(**data["inputs2"]) + with torch_export_patches( + patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1 + ): + torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_models/test_llm_phi2.py b/_unittests/ut_torch_models/test_llm_phi2.py index 77002654..1c2e9110 100644 --- a/_unittests/ut_torch_models/test_llm_phi2.py +++ b/_unittests/ut_torch_models/test_llm_phi2.py @@ -8,7 +8,10 @@ ) from onnx_diagnostic.torch_models.llms import get_phi2 from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches import ( + torch_export_patches, + register_additional_serialization_functions, +) from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str @@ -21,8 +24,8 @@ def test_get_phi2(self): @ignore_warnings(UserWarning) @requires_transformers("4.54") - @requires_torch("2.9.99") - def test_export_phi2_1_batch_size_1(self): + @requires_torch("2.10.99") + def test_export_phi2_1_batch_size_1_oblivious(self): # exporting vmap does not work data = get_phi2(num_hidden_layers=2, batch_size=1) model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] @@ -38,6 +41,40 @@ def test_export_phi2_1_batch_size_1(self): ) assert ep + @ignore_warnings(UserWarning) + @requires_transformers("4.54") + @requires_torch("2.9.99") + def test_export_phi2_1_batch_size_1_not_oblivious(self): + # exporting vmap does not work + data = get_phi2(num_hidden_layers=2, batch_size=1) + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + self.assertEqual(inputs["input_ids"].shape[0], 1) + self.assertEqual( + {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) + ) + with torch_export_patches(patch_transformers=True): + ep = torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + assert ep + + @ignore_warnings(UserWarning) + @requires_transformers("4.54") + @requires_torch("2.12") + def test_export_phi2_1_batch_size_1_no_patch(self): + # exporting vmap does not work + data = 
get_phi2(num_hidden_layers=2, batch_size=1) + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + self.assertEqual(inputs["input_ids"].shape[0], 1) + self.assertEqual( + {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) + ) + with register_additional_serialization_functions(patch_transformers=True): + ep = torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + assert ep + @ignore_warnings(UserWarning) @requires_transformers("4.54") @requires_torch("2.9.99")
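
For reference, a minimal usage sketch of the new 0/1-dimension patch, adapted from the
test_oblivious_for_dimension_01 test added in PATCH 1/5. It assumes a build of
onnx_diagnostic that includes these patches; without them, this export raises
"export 0/1 specialized due to hint of 1 for dimension ..." because ind1 has a
sample shape of (2, 1) and the dimension with hint 1 gets specialized.

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches
    from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


    class Model(torch.nn.Module):
        def forward(self, x, ind1, ind2):
            return x[ind1, ind2]


    inputs = (
        torch.randn(2, 1024),
        torch.tensor([[0, 1]], dtype=torch.int64).T,  # shape (2, 1): hint of 1 on dim "D"
        torch.arange(1024, dtype=torch.int64),
    )
    # string dimension names are converted to torch.export.Dim.DYNAMIC
    dynamic_shapes = use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"}))

    # the context manager swaps in the patched torch helpers, including
    # patched__constrain_user_specified_dimhint_range, for the duration of the export
    with torch_export_patches():
        ep = torch.export.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
    print(ep.module()(*inputs).shape)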