2 changes: 2 additions & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
0.7.13
++++++

* :pr:`244`: add a patch to bypass the exception raised when a dynamic dimension is in {0, 1}

0.7.12
++++++

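A minimal sketch of the issue this patch works around, assuming a toy model (the model and shapes below are illustrative, not taken from this PR): exporting with a dynamic dimension whose example value is 1 can make ``torch.export`` specialize that dimension instead of keeping it dynamic.

```python
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Model(torch.nn.Module):
    def forward(self, x):
        return x * 2

# The example input has a batch dimension of 1, so torch.export may
# specialize dimension 0 to 1 and raise; the patches bypass that exception.
x = torch.randn(1, 8)
with torch_export_patches():
    ep = torch.export.export(
        Model(), (x,), dynamic_shapes=({0: torch.export.Dim.DYNAMIC},)
    )
```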
2 changes: 2 additions & 0 deletions _doc/conf.py
@@ -114,6 +114,8 @@ def linkcode_resolve(domain, info):
nitpicky = True
# See also scikit-learn/scikit-learn#26761
nitpick_ignore = [
("py:class", "_DimHint"),
("py:class", "KeyPath"),
("py:class", "ast.Node"),
("py:class", "dtype"),
("py:class", "False"),
51 changes: 49 additions & 2 deletions _unittests/ut_tasks/test_tasks_image_to_video.py
@@ -17,8 +17,8 @@ class TestTasksImageToVideo(ExtTestCase):
@hide_stdout()
@requires_diffusers("0.35")
@requires_transformers("4.55")
@requires_torch("2.8.99")
def test_image_to_video(self):
@requires_torch("2.10.99")
def test_image_to_video_oblivious(self):
kwargs = {
"_diffusers_version": "0.34.0.dev0",
"_class_name": "CosmosTransformer3DModel",
@@ -63,6 +63,53 @@ def test_image_to_video(self):
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)

@hide_stdout()
@requires_diffusers("0.35")
@requires_transformers("4.55")
@requires_torch("2.8.99")
def test_image_to_video_not_oblivious(self):
kwargs = {
"_diffusers_version": "0.34.0.dev0",
"_class_name": "CosmosTransformer3DModel",
"max_size": [128, 240, 240],
"text_embed_dim": 128,
"use_cache": True,
"in_channels": 3,
"out_channels": 16,
"num_layers": 2,
"model_type": "dia",
"patch_size": [1, 2, 2],
"rope_scale": [1.0, 3.0, 3.0],
"attention_head_dim": 16,
"mlp_ratio": 0.4,
"initializer_range": 0.02,
"num_attention_heads": 16,
"is_encoder_decoder": True,
"adaln_lora_dim": 16,
"concat_padding_mask": True,
"extra_pos_embed_type": None,
}
config = transformers.DiaConfig(**kwargs)
mid = "nvidia/Cosmos-Predict2-2B-Video2World"
data = get_untrained_model_with_inputs(
mid,
verbose=1,
add_second_input=True,
subfolder="transformer",
config=config,
inputs_kwargs=dict(image_height=8 * 50, image_width=8 * 80),
)
self.assertEqual(data["task"], "image-to-video")
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with torch_export_patches(
patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)


if __name__ == "__main__":
unittest.main(verbosity=2)
13 changes: 0 additions & 13 deletions _unittests/ut_torch_export_patches/test_patch_rewrite.py

This file was deleted.

5 changes: 5 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_rewriting.py
@@ -3,6 +3,7 @@
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
rewrite_loop_for_square_mask,
)
from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting


class TestPatchRewriting(ExtTestCase):
@@ -33,6 +34,10 @@ def apply_mask(mask, seq):
m2 = rewrite_loop_for_square_mask(mask, seq)
self.assertEqualArray(m1, m2)

def test_code_needing_rewriting(self):
res = code_needing_rewriting("BartModel")
self.assertEqual(len(res), 2)


if __name__ == "__main__":
unittest.main(verbosity=2)
82 changes: 81 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_torch.py
@@ -2,7 +2,14 @@
from typing import Callable
import torch
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
requires_torch,
requires_transformers,
has_torch,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestPatchPatchTorch(ExtTestCase):
@@ -236,6 +243,79 @@ def forward(self, x):
ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
self.assertEqualArray(Model()(x), ep.module()(x))

def test_oblivious_for_dimension_01(self):
class Model(torch.nn.Module):
def forward(self, x, ind1, ind2):
return x[ind1, ind2]

inputs = (
torch.randn(2, 1024),
torch.tensor([[0, 1]], dtype=torch.int64).T,
torch.arange(1024, dtype=torch.int64),
)
model = Model()
expected = model(*inputs)

dynamic_string = ({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})
# ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN})

dynamic_shapes = use_dyn_not_str(dynamic_string)
with self.subTest(
name="export 0/1 specialized due to hint of 1 for dimension",
dynamic_shapes=dynamic_shapes,
):
try:
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
raise AssertionError("torch fixed that case")
except ValueError as e:
self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e))

dynamic_shapes = use_dyn_not_str(dynamic_string, torch.export.Dim.AUTO)
if has_torch("2.9"):
with self.subTest(
name="expected shape should be broadcastable to (>= 2.9)",
dynamic_shapes=dynamic_shapes,
):
try:
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
raise AssertionError("torch fixed that case")
except RuntimeError as e:
self.assertIn("expected shape should be broadcastable to", str(e))

if not has_torch("2.9"):
with self.subTest(
name="expected shape should be broadcastable to (< 2.9)",
dynamic_shapes=dynamic_shapes,
):
try:
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
except RuntimeError as e:
self.assertIn(
"Expected input at *args[2].shape[0] to be equal to 1, but got 1024",
str(e),
)

with self.subTest(name="patch for 0/1", dynamic_shapes=dynamic_shapes):
with torch_export_patches():
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
got = ep.module()(*inputs)
self.assertEqualArray(expected, got)

if has_torch("2.11"):
# Missing PR https://github.com/pytorch/pytorch/pull/164225
# Needs more thinking about the patch to apply for this particular example.
with self.subTest(
name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes
):
with torch_export_patches(), torch.fx.experimental._config.patch(
backed_size_oblivious=True
):
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
got = ep.module()(*inputs)
self.assertEqualArray(expected, got)


if __name__ == "__main__":
unittest.main(verbosity=2)
43 changes: 40 additions & 3 deletions _unittests/ut_torch_models/test_llm_phi2.py
@@ -8,7 +8,10 @@
)
from onnx_diagnostic.torch_models.llms import get_phi2
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches import (
torch_export_patches,
register_additional_serialization_functions,
)
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


@@ -21,8 +24,8 @@ def test_get_phi2(self):

@ignore_warnings(UserWarning)
@requires_transformers("4.54")
@requires_torch("2.9.99")
def test_export_phi2_1_batch_size_1(self):
@requires_torch("2.10.99")
def test_export_phi2_1_batch_size_1_oblivious(self):
# exporting vmap does not work
data = get_phi2(num_hidden_layers=2, batch_size=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -38,6 +41,40 @@
)
assert ep

@ignore_warnings(UserWarning)
@requires_transformers("4.54")
@requires_torch("2.9.99")
def test_export_phi2_1_batch_size_1_not_oblivious(self):
# exporting vmap does not work
data = get_phi2(num_hidden_layers=2, batch_size=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
self.assertEqual(inputs["input_ids"].shape[0], 1)
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
with torch_export_patches(patch_transformers=True):
ep = torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
)
assert ep

@ignore_warnings(UserWarning)
@requires_transformers("4.54")
@requires_torch("2.12")
def test_export_phi2_1_batch_size_1_no_patch(self):
# exporting vmap does not work
data = get_phi2(num_hidden_layers=2, batch_size=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
self.assertEqual(inputs["input_ids"].shape[0], 1)
self.assertEqual(
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
)
with register_additional_serialization_functions(patch_transformers=True):
ep = torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
)
assert ep

@ignore_warnings(UserWarning)
@requires_transformers("4.54")
@requires_torch("2.9.99")
12 changes: 12 additions & 0 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -341,6 +341,7 @@ def torch_export_patches(
patched_infer_size,
patched_vmap,
patched__broadcast_shapes,
patched__constrain_user_specified_dimhint_range,
_catch_produce_guards_and_solve_constraints,
patch__check_input_constraints_for_graph,
)
@@ -371,6 +372,14 @@ def torch_export_patches(
torch._refs._broadcast_shapes = patched__broadcast_shapes
torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes

# torch._export.non_strict_utils._constrain_user_specified_dimhint_range
f___constrain_user_specified_dimhint_range = (
torch._export.non_strict_utils._constrain_user_specified_dimhint_range
)
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
patched__constrain_user_specified_dimhint_range
)

# torch._export.non_strict_utils.produce_guards_and_solve_constraints
if patch_torch and catch_constraints:
if verbose:
@@ -569,6 +578,9 @@ def torch_export_patches(
torch._subclasses.fake_impls.infer_size = f_infer_size
torch._refs._broadcast_shapes = f__broadcast_shapes
torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
f___constrain_user_specified_dimhint_range
)

if verbose:
print("[torch_export_patches] restored pytorch functions")
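The new patch follows the same save/patch/restore discipline as the other patched functions: the original ``_constrain_user_specified_dimhint_range`` is stored when the context manager enters and reinstalled when it exits. A minimal sketch of that pattern, with a hypothetical helper name (the real logic lives inside ``torch_export_patches``):

```python
import contextlib
import torch
import torch._export.non_strict_utils  # ensure the submodule is loaded

@contextlib.contextmanager
def patch_dimhint_constraint(patched_fn):
    # Save the original so it can be restored even if export raises.
    original = torch._export.non_strict_utils._constrain_user_specified_dimhint_range
    torch._export.non_strict_utils._constrain_user_specified_dimhint_range = patched_fn
    try:
        yield
    finally:
        torch._export.non_strict_utils._constrain_user_specified_dimhint_range = original
```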
16 changes: 10 additions & 6 deletions onnx_diagnostic/torch_export_patches/patch_inputs.py
@@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes(
return (), updated_kwargs, dynamic_shapes


def use_dyn_not_str(dynamic_shapes: Any) -> Any:
def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
"""
Some functions return dynamic shapes as strings.
This function replaces them with ``torch.export.Dim.DYNAMIC``.
Passing ``default_value=torch.export.Dim.AUTO`` changes the replacement value.
"""
if isinstance(dynamic_shapes, list):
return [use_dyn_not_str(a) for a in dynamic_shapes]
return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
if isinstance(dynamic_shapes, tuple):
return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
if isinstance(dynamic_shapes, dict):
return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()}
return {
k: use_dyn_not_str(v, default_value=default_value)
for k, v in dynamic_shapes.items()
}
if isinstance(dynamic_shapes, set):
return {use_dyn_not_str(a) for a in dynamic_shapes}
return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
if isinstance(dynamic_shapes, str):
return torch.export.Dim.DYNAMIC
return torch.export.Dim.DYNAMIC if default_value is None else default_value
return dynamic_shapes
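A short usage sketch of the extended signature, mirroring how the tests above call it:

```python
import torch
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

ds = ({0: "batch", 1: "seq"}, {0: "batch"})
# Strings become torch.export.Dim.DYNAMIC by default.
print(use_dyn_not_str(ds))
# With default_value, strings become that value instead, e.g. Dim.AUTO.
print(use_dyn_not_str(ds, default_value=torch.export.Dim.AUTO))
```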