Skip to content

Commit e090445

Browse files
committed
Merge branch 'main' of https://github.com/sdpython/onnx-diagnostic into gemma

Commit e090445 (2 parents: 4a4b722 + 7158b2f)

File tree

17 files changed

+414
-67
lines changed

17 files changed

+414
-67
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
name: to-${{ matrix.torch }}-tr-${{ matrix.transformers }}-ci ${{ matrix.os }}-${{ matrix.python }}
1414
runs-on: ${{ matrix.os }}
1515
strategy:
16+
fail-fast: false
1617
matrix:
1718
os: [ubuntu-latest]
1819
python: ['3.10', '3.11', '3.12', '3.13']

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.7.13
5+
++++++
6+
7+
* :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
8+
49
0.7.12
510
++++++
611

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def linkcode_resolve(domain, info):
114114
nitpicky = True
115115
# See also scikit-learn/scikit-learn#26761
116116
nitpick_ignore = [
117+
("py:class", "_DimHint"),
118+
("py:class", "KeyPath"),
117119
("py:class", "ast.Node"),
118120
("py:class", "dtype"),
119121
("py:class", "False"),

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ The function replaces dynamic dimensions defined as strings by
239239
Older versions
240240
==============
241241

242+
* `0.7.13 <../v0.7.13/index.html>`_
242243
* `0.7.12 <../v0.7.12/index.html>`_
243244
* `0.7.11 <../v0.7.11/index.html>`_
244245
* `0.6.3 <../v0.6.3/index.html>`_

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def test_replace_string_by_dynamic(self):
273273
"_DimHint(type=<_DimHintType.DYNAMIC:3>,min=None,max=None,_factory=True)",
274274
"DYN",
275275
)
276+
.replace("DimHint(DYNAMIC)", "DYN")
276277
)
277278
self.assertEqual(
278279
"{'input_ids':{0:DYN,1:DYN},'attention_mask':({0:DYN,1:DYN},),'position_ids':[{0:DYN,1:DYN}]}",

_unittests/ut_tasks/test_tasks_image_to_video.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class TestTasksImageToVideo(ExtTestCase):
1717
@hide_stdout()
1818
@requires_diffusers("0.35")
1919
@requires_transformers("4.55")
20-
@requires_torch("2.8.99")
21-
def test_image_to_video(self):
20+
@requires_torch("2.10.99")
21+
def test_image_to_video_oblivious(self):
2222
kwargs = {
2323
"_diffusers_version": "0.34.0.dev0",
2424
"_class_name": "CosmosTransformer3DModel",
@@ -63,6 +63,53 @@ def test_image_to_video(self):
6363
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
6464
)
6565

66+
@hide_stdout()
67+
@requires_diffusers("0.35")
68+
@requires_transformers("4.55")
69+
@requires_torch("2.8.99")
70+
def test_image_to_video_not_oblivious(self):
71+
kwargs = {
72+
"_diffusers_version": "0.34.0.dev0",
73+
"_class_name": "CosmosTransformer3DModel",
74+
"max_size": [128, 240, 240],
75+
"text_embed_dim": 128,
76+
"use_cache": True,
77+
"in_channels": 3,
78+
"out_channels": 16,
79+
"num_layers": 2,
80+
"model_type": "dia",
81+
"patch_size": [1, 2, 2],
82+
"rope_scale": [1.0, 3.0, 3.0],
83+
"attention_head_dim": 16,
84+
"mlp_ratio": 0.4,
85+
"initializer_range": 0.02,
86+
"num_attention_heads": 16,
87+
"is_encoder_decoder": True,
88+
"adaln_lora_dim": 16,
89+
"concat_padding_mask": True,
90+
"extra_pos_embed_type": None,
91+
}
92+
config = transformers.DiaConfig(**kwargs)
93+
mid = "nvidia/Cosmos-Predict2-2B-Video2World"
94+
data = get_untrained_model_with_inputs(
95+
mid,
96+
verbose=1,
97+
add_second_input=True,
98+
subfolder="transformer",
99+
config=config,
100+
inputs_kwargs=dict(image_height=8 * 50, image_width=8 * 80),
101+
)
102+
self.assertEqual(data["task"], "image-to-video")
103+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
104+
model(**inputs)
105+
model(**data["inputs2"])
106+
with torch_export_patches(
107+
patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
108+
):
109+
torch.export.export(
110+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
111+
)
112+
66113

67114
if __name__ == "__main__":
68115
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_rewrite.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

_unittests/ut_torch_export_patches/test_patch_rewriting.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
44
rewrite_loop_for_square_mask,
55
)
6+
from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting
67

78

89
class TestPatchRewriting(ExtTestCase):
@@ -33,6 +34,10 @@ def apply_mask(mask, seq):
3334
m2 = rewrite_loop_for_square_mask(mask, seq)
3435
self.assertEqualArray(m1, m2)
3536

37+
def test_code_needing_rewriting(self):
38+
res = code_needing_rewriting("BartModel")
39+
self.assertEqual(len(res), 2)
40+
3641

3742
if __name__ == "__main__":
3843
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
from typing import Callable
33
import torch
44
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
5-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
5+
from onnx_diagnostic.ext_test_case import (
6+
ExtTestCase,
7+
requires_torch,
8+
requires_transformers,
9+
has_torch,
10+
)
11+
from onnx_diagnostic.torch_export_patches import torch_export_patches
12+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
613

714

815
class TestPatchPatchTorch(ExtTestCase):
@@ -236,6 +243,79 @@ def forward(self, x):
236243
ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
237244
self.assertEqualArray(Model()(x), ep.module()(x))
238245

246+
def test_oblivious_for_dimension_01(self):
247+
class Model(torch.nn.Module):
248+
def forward(self, x, ind1, ind2):
249+
return x[ind1, ind2]
250+
251+
inputs = (
252+
torch.randn(2, 1024),
253+
torch.tensor([[0, 1]], dtype=torch.int64).T,
254+
torch.arange(1024, dtype=torch.int64),
255+
)
256+
model = Model()
257+
expected = model(*inputs)
258+
259+
dynamic_string = ({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})
260+
# ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN})
261+
262+
dynamic_shapes = use_dyn_not_str(dynamic_string)
263+
with self.subTest(
264+
name="export 0/1 specialized due to hint of 1 for dimension",
265+
dynamic_shapes=dynamic_shapes,
266+
):
267+
try:
268+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
269+
raise AssertionError("torch fixed that case")
270+
except ValueError as e:
271+
self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e))
272+
273+
dynamic_shapes = use_dyn_not_str(dynamic_string, torch.export.Dim.AUTO)
274+
if has_torch("2.9"):
275+
with self.subTest(
276+
name="expected shape should be broadcastable to (>= 2.9)",
277+
dynamic_shapes=dynamic_shapes,
278+
):
279+
try:
280+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
281+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
282+
raise AssertionError("torch fixed that case")
283+
except RuntimeError as e:
284+
self.assertIn("expected shape should be broadcastable to", str(e))
285+
286+
if not has_torch("2.9"):
287+
with self.subTest(
288+
name="expected shape should be broadcastable to (< 2.9)",
289+
dynamic_shapes=dynamic_shapes,
290+
):
291+
try:
292+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
293+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
294+
except RuntimeError as e:
295+
self.assertIn(
296+
"Expected input at *args[2].shape[0] to be equal to 1, but got 1024",
297+
str(e),
298+
)
299+
300+
with self.subTest(name="patch for 0/1", dynamic_shapes=dynamic_shapes):
301+
with torch_export_patches():
302+
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
303+
got = ep.module()(*inputs)
304+
self.assertEqualArray(expected, got)
305+
306+
if has_torch("2.11"):
307+
# Missing PR https://github.com/pytorch/pytorch/pull/164225
308+
# Needs more thinking about the patch to apply for this particular example.
309+
with self.subTest(
310+
name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes
311+
):
312+
with torch_export_patches(), torch.fx.experimental._config.patch(
313+
backed_size_oblivious=True
314+
):
315+
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
316+
got = ep.module()(*inputs)
317+
self.assertEqualArray(expected, got)
318+
239319

240320
if __name__ == "__main__":
241321
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
class TestHuggingFaceHubApi(ExtTestCase):
3030

31+
@unittest.skip("https://github.com/sdpython/onnx-diagnostic/issues/242")
3132
@requires_transformers("4.50") # we limit to some versions of the CI
3233
@requires_torch("2.7")
3334
@ignore_errors(OSError) # connectivity issues

0 commit comments

Comments (0)