Skip to content

Commit 7158b2f

Browse files
authored
Add a patch for dimension in 0/1 (#244)
* Add a patch for dimension in 0/1 * doc * auto * fix issues * fix unit test
1 parent 7e35e7f commit 7158b2f

File tree

10 files changed

+317
-35
lines changed

10 files changed

+317
-35
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.13
55
++++++
66

7+
* :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
8+
79
0.7.12
810
++++++
911

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def linkcode_resolve(domain, info):
114114
nitpicky = True
115115
# See also scikit-learn/scikit-learn#26761
116116
nitpick_ignore = [
117+
("py:class", "_DimHint"),
118+
("py:class", "KeyPath"),
117119
("py:class", "ast.Node"),
118120
("py:class", "dtype"),
119121
("py:class", "False"),

_unittests/ut_tasks/test_tasks_image_to_video.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class TestTasksImageToVideo(ExtTestCase):
1717
@hide_stdout()
1818
@requires_diffusers("0.35")
1919
@requires_transformers("4.55")
20-
@requires_torch("2.8.99")
21-
def test_image_to_video(self):
20+
@requires_torch("2.10.99")
21+
def test_image_to_video_oblivious(self):
2222
kwargs = {
2323
"_diffusers_version": "0.34.0.dev0",
2424
"_class_name": "CosmosTransformer3DModel",
@@ -63,6 +63,53 @@ def test_image_to_video(self):
6363
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
6464
)
6565

66+
@hide_stdout()
67+
@requires_diffusers("0.35")
68+
@requires_transformers("4.55")
69+
@requires_torch("2.8.99")
70+
def test_image_to_video_not_oblivious(self):
71+
kwargs = {
72+
"_diffusers_version": "0.34.0.dev0",
73+
"_class_name": "CosmosTransformer3DModel",
74+
"max_size": [128, 240, 240],
75+
"text_embed_dim": 128,
76+
"use_cache": True,
77+
"in_channels": 3,
78+
"out_channels": 16,
79+
"num_layers": 2,
80+
"model_type": "dia",
81+
"patch_size": [1, 2, 2],
82+
"rope_scale": [1.0, 3.0, 3.0],
83+
"attention_head_dim": 16,
84+
"mlp_ratio": 0.4,
85+
"initializer_range": 0.02,
86+
"num_attention_heads": 16,
87+
"is_encoder_decoder": True,
88+
"adaln_lora_dim": 16,
89+
"concat_padding_mask": True,
90+
"extra_pos_embed_type": None,
91+
}
92+
config = transformers.DiaConfig(**kwargs)
93+
mid = "nvidia/Cosmos-Predict2-2B-Video2World"
94+
data = get_untrained_model_with_inputs(
95+
mid,
96+
verbose=1,
97+
add_second_input=True,
98+
subfolder="transformer",
99+
config=config,
100+
inputs_kwargs=dict(image_height=8 * 50, image_width=8 * 80),
101+
)
102+
self.assertEqual(data["task"], "image-to-video")
103+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
104+
model(**inputs)
105+
model(**data["inputs2"])
106+
with torch_export_patches(
107+
patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
108+
):
109+
torch.export.export(
110+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
111+
)
112+
66113

67114
if __name__ == "__main__":
68115
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_rewrite.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

_unittests/ut_torch_export_patches/test_patch_rewriting.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
44
rewrite_loop_for_square_mask,
55
)
6+
from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting
67

78

89
class TestPatchRewriting(ExtTestCase):
@@ -33,6 +34,10 @@ def apply_mask(mask, seq):
3334
m2 = rewrite_loop_for_square_mask(mask, seq)
3435
self.assertEqualArray(m1, m2)
3536

37+
def test_code_needing_rewriting(self):
38+
res = code_needing_rewriting("BartModel")
39+
self.assertEqual(len(res), 2)
40+
3641

3742
if __name__ == "__main__":
3843
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
from typing import Callable
33
import torch
44
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
5-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
5+
from onnx_diagnostic.ext_test_case import (
6+
ExtTestCase,
7+
requires_torch,
8+
requires_transformers,
9+
has_torch,
10+
)
11+
from onnx_diagnostic.torch_export_patches import torch_export_patches
12+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
613

714

815
class TestPatchPatchTorch(ExtTestCase):
@@ -236,6 +243,79 @@ def forward(self, x):
236243
ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
237244
self.assertEqualArray(Model()(x), ep.module()(x))
238245

246+
def test_oblivious_for_dimension_01(self):
247+
class Model(torch.nn.Module):
248+
def forward(self, x, ind1, ind2):
249+
return x[ind1, ind2]
250+
251+
inputs = (
252+
torch.randn(2, 1024),
253+
torch.tensor([[0, 1]], dtype=torch.int64).T,
254+
torch.arange(1024, dtype=torch.int64),
255+
)
256+
model = Model()
257+
expected = model(*inputs)
258+
259+
dynamic_string = ({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})
260+
# ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN})
261+
262+
dynamic_shapes = use_dyn_not_str(dynamic_string)
263+
with self.subTest(
264+
name="export 0/1 specialized due to hint of 1 for dimension",
265+
dynamic_shapes=dynamic_shapes,
266+
):
267+
try:
268+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
269+
raise AssertionError("torch fixed that case")
270+
except ValueError as e:
271+
self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e))
272+
273+
dynamic_shapes = use_dyn_not_str(dynamic_string, torch.export.Dim.AUTO)
274+
if has_torch("2.9"):
275+
with self.subTest(
276+
name="expected shape should be broadcastable to (>= 2.9)",
277+
dynamic_shapes=dynamic_shapes,
278+
):
279+
try:
280+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
281+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
282+
raise AssertionError("torch fixed that case")
283+
except RuntimeError as e:
284+
self.assertIn("expected shape should be broadcastable to", str(e))
285+
286+
if not has_torch("2.9"):
287+
with self.subTest(
288+
name="expected shape should be broadcastable to (< 2.9)",
289+
dynamic_shapes=dynamic_shapes,
290+
):
291+
try:
292+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
293+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
294+
except RuntimeError as e:
295+
self.assertIn(
296+
"Expected input at *args[2].shape[0] to be equal to 1, but got 1024",
297+
str(e),
298+
)
299+
300+
with self.subTest(name="patch for 0/1", dynamic_shapes=dynamic_shapes):
301+
with torch_export_patches():
302+
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
303+
got = ep.module()(*inputs)
304+
self.assertEqualArray(expected, got)
305+
306+
if has_torch("2.11"):
307+
# Missing PR https://github.com/pytorch/pytorch/pull/164225
308+
# Needs more thinking about the patch to apply for this particular example.
309+
with self.subTest(
310+
name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes
311+
):
312+
with torch_export_patches(), torch.fx.experimental._config.patch(
313+
backed_size_oblivious=True
314+
):
315+
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
316+
got = ep.module()(*inputs)
317+
self.assertEqualArray(expected, got)
318+
239319

240320
if __name__ == "__main__":
241321
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
)
99
from onnx_diagnostic.torch_models.llms import get_phi2
1010
from onnx_diagnostic.helpers import string_type
11-
from onnx_diagnostic.torch_export_patches import torch_export_patches
11+
from onnx_diagnostic.torch_export_patches import (
12+
torch_export_patches,
13+
register_additional_serialization_functions,
14+
)
1215
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1316

1417

@@ -21,8 +24,8 @@ def test_get_phi2(self):
2124

2225
@ignore_warnings(UserWarning)
2326
@requires_transformers("4.54")
24-
@requires_torch("2.9.99")
25-
def test_export_phi2_1_batch_size_1(self):
27+
@requires_torch("2.10.99")
28+
def test_export_phi2_1_batch_size_1_oblivious(self):
2629
# exporting vmap does not work
2730
data = get_phi2(num_hidden_layers=2, batch_size=1)
2831
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -38,6 +41,40 @@ def test_export_phi2_1_batch_size_1(self):
3841
)
3942
assert ep
4043

44+
@ignore_warnings(UserWarning)
45+
@requires_transformers("4.54")
46+
@requires_torch("2.9.99")
47+
def test_export_phi2_1_batch_size_1_not_oblivious(self):
48+
# exporting vmap does not work
49+
data = get_phi2(num_hidden_layers=2, batch_size=1)
50+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
51+
self.assertEqual(inputs["input_ids"].shape[0], 1)
52+
self.assertEqual(
53+
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
54+
)
55+
with torch_export_patches(patch_transformers=True):
56+
ep = torch.export.export(
57+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
58+
)
59+
assert ep
60+
61+
@ignore_warnings(UserWarning)
62+
@requires_transformers("4.54")
63+
@requires_torch("2.12")
64+
def test_export_phi2_1_batch_size_1_no_patch(self):
65+
# exporting vmap does not work
66+
data = get_phi2(num_hidden_layers=2, batch_size=1)
67+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
68+
self.assertEqual(inputs["input_ids"].shape[0], 1)
69+
self.assertEqual(
70+
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
71+
)
72+
with register_additional_serialization_functions(patch_transformers=True):
73+
ep = torch.export.export(
74+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
75+
)
76+
assert ep
77+
4178
@ignore_warnings(UserWarning)
4279
@requires_transformers("4.54")
4380
@requires_torch("2.9.99")

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def torch_export_patches(
341341
patched_infer_size,
342342
patched_vmap,
343343
patched__broadcast_shapes,
344+
patched__constrain_user_specified_dimhint_range,
344345
_catch_produce_guards_and_solve_constraints,
345346
patch__check_input_constraints_for_graph,
346347
)
@@ -371,6 +372,14 @@ def torch_export_patches(
371372
torch._refs._broadcast_shapes = patched__broadcast_shapes
372373
torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
373374

375+
# torch._export.non_strict_utils._constrain_user_specified_dimhint_range
376+
f___constrain_user_specified_dimhint_range = (
377+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range
378+
)
379+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
380+
patched__constrain_user_specified_dimhint_range
381+
)
382+
374383
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
375384
if patch_torch and catch_constraints:
376385
if verbose:
@@ -569,6 +578,9 @@ def torch_export_patches(
569578
torch._subclasses.fake_impls.infer_size = f_infer_size
570579
torch._refs._broadcast_shapes = f__broadcast_shapes
571580
torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
581+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
582+
f___constrain_user_specified_dimhint_range
583+
)
572584

573585
if verbose:
574586
print("[torch_export_patches] restored pytorch functions")

onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes(
189189
return (), updated_kwargs, dynamic_shapes
190190

191191

192-
def use_dyn_not_str(dynamic_shapes: Any) -> Any:
192+
def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
193193
"""
194194
Some functions return dynamic shapes as strings.
195195
This function replaces them with ``torch.export.Dim.DYNAMIC``.
196+
If ``default_value`` is given (e.g. ``torch.export.Dim.AUTO``), strings are replaced with that value instead.
196197
"""
197198
if isinstance(dynamic_shapes, list):
198-
return [use_dyn_not_str(a) for a in dynamic_shapes]
199+
return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
199200
if isinstance(dynamic_shapes, tuple):
200-
return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
201+
return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
201202
if isinstance(dynamic_shapes, dict):
202-
return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()}
203+
return {
204+
k: use_dyn_not_str(v, default_value=default_value)
205+
for k, v in dynamic_shapes.items()
206+
}
203207
if isinstance(dynamic_shapes, set):
204-
return {use_dyn_not_str(a) for a in dynamic_shapes}
208+
return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
205209
if isinstance(dynamic_shapes, str):
206-
return torch.export.Dim.DYNAMIC
210+
return torch.export.Dim.DYNAMIC if default_value is None else default_value
207211
return dynamic_shapes

0 commit comments

Comments
 (0)