Skip to content

Commit 15ab96e

Browse files
committed
auto
1 parent bccdc0c commit 15ab96e

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
from typing import Callable
33
import torch
44
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
5-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
5+
from onnx_diagnostic.ext_test_case import (
6+
ExtTestCase,
7+
requires_torch,
8+
requires_transformers,
9+
has_torch,
10+
)
611
from onnx_diagnostic.torch_export_patches import torch_export_patches
712
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
813

@@ -252,28 +257,61 @@ def forward(self, x, ind1, ind2):
252257
expected = model(*inputs)
253258

254259
dynamic_string = ({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})
260+
# ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN})
261+
255262
dynamic_shapes = use_dyn_not_str(dynamic_string)
256-
with self.subTest(name="export 0/1 specialized due to hint of 1 for dimension"):
263+
with self.subTest(
264+
name="export 0/1 specialized due to hint of 1 for dimension",
265+
dynamic_shapes=dynamic_shapes,
266+
):
257267
try:
258268
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
259269
raise AssertionError("torch fixed that case")
260270
except ValueError as e:
261271
self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e))
262272

263-
with self.subTest(name="expected shape should be broadcastable to"):
264-
try:
273+
dynamic_shapes = use_dyn_not_str(dynamic_string, torch.export.Dim.AUTO)
274+
if has_torch("2.9"):
275+
with self.subTest(
276+
name="expected shape should be broadcastable to (>= 2.9)",
277+
dynamic_shapes=dynamic_shapes,
278+
):
279+
try:
280+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
281+
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
282+
raise AssertionError("torch fixed that case")
283+
except RuntimeError as e:
284+
self.assertIn("expected shape should be broadcastable to", str(e))
285+
286+
if not has_torch("2.9"):
287+
with self.subTest(
288+
name="expected shape should be broadcastable to (< 2.9)",
289+
dynamic_shapes=dynamic_shapes,
290+
):
265291
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
266-
torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
267-
raise AssertionError("torch fixed that case")
268-
except RuntimeError as e:
269-
self.assertIn("expected shape should be broadcastable to", str(e))
292+
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
293+
got = ep.module()(*inputs)
294+
self.assertEqualArray(expected, got)
270295

271-
with self.subTest(name="patch for 0/1"):
296+
with self.subTest(name="patch for 0/1", dynamic_shapes=dynamic_shapes):
272297
with torch_export_patches():
273298
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
274299
got = ep.module()(*inputs)
275300
self.assertEqualArray(expected, got)
276301

302+
if has_torch("2.11"):
303+
# Missing PR https://github.com/pytorch/pytorch/pull/164225
304+
# Needs more thinking about the patch to apply for this particular example.
305+
with self.subTest(
306+
name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes
307+
):
308+
with torch_export_patches(), torch.fx.experimental._config.patch(
309+
backed_size_oblivious=True
310+
):
311+
ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
312+
got = ep.module()(*inputs)
313+
self.assertEqualArray(expected, got)
314+
277315

278316
if __name__ == "__main__":
279317
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/patch_inputs.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes(
189189
return (), updated_kwargs, dynamic_shapes
190190

191191

192-
def use_dyn_not_str(dynamic_shapes: Any) -> Any:
192+
def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
193193
"""
194194
Some functions returns dynamic shapes as string.
195195
This functions replaces them with ``torch.export.Dim.DYNAMIC``.
196+
``default_value=torch.export.Dim.AUTO`` changes the default value.
196197
"""
197198
if isinstance(dynamic_shapes, list):
198-
return [use_dyn_not_str(a) for a in dynamic_shapes]
199+
return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
199200
if isinstance(dynamic_shapes, tuple):
200-
return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
201+
return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
201202
if isinstance(dynamic_shapes, dict):
202-
return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()}
203+
return {
204+
k: use_dyn_not_str(v, default_value=default_value)
205+
for k, v in dynamic_shapes.items()
206+
}
203207
if isinstance(dynamic_shapes, set):
204-
return {use_dyn_not_str(a) for a in dynamic_shapes}
208+
return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
205209
if isinstance(dynamic_shapes, str):
206-
return torch.export.Dim.DYNAMIC
210+
return torch.export.Dim.DYNAMIC if default_value is None else default_value
207211
return dynamic_shapes

0 commit comments

Comments
 (0)