Commit c7e3ea1

Add a patch for dimension in 0/1
1 parent 7e35e7f commit c7e3ea1
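
For context, the failure this commit works around is torch.export specializing a dynamic dimension to 0/1 when the sample input happens to have size 1 there. The following condensed version of the test added in this commit (test_oblivious_for_dimension_01) shows the intent; it assumes onnx_diagnostic with this patch is installed.

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class Model(torch.nn.Module):
    def forward(self, x, ind1, ind2):
        return x[ind1, ind2]


inputs = (
    torch.randn(2, 1024),
    torch.tensor([[0, 1]], dtype=torch.int64).T,  # shape (2, 1): the 1 triggers 0/1 specialization
    torch.arange(1024, dtype=torch.int64),
)
dynamic_shapes = use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"}))

# Without the patch, torch.export.export raises
# "export 0/1 specialized due to hint of 1 for dimension ...".
# With the patches applied, the export succeeds and matches eager execution.
with torch_export_patches():
    ep = torch.export.export(Model(), inputs, dynamic_shapes=dynamic_shapes)
assert torch.equal(Model()(*inputs), ep.module()(*inputs))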

File tree

5 files changed  +171  −23 lines changed


_unittests/ut_torch_export_patches/test_patch_rewrite.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

_unittests/ut_torch_export_patches/test_patch_rewriting.py

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,7 @@
 from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
     rewrite_loop_for_square_mask,
 )
+from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting


 class TestPatchRewriting(ExtTestCase):
@@ -33,6 +34,10 @@ def apply_mask(mask, seq):
         m2 = rewrite_loop_for_square_mask(mask, seq)
         self.assertEqualArray(m1, m2)

+    def test_code_needing_rewriting(self):
+        res = code_needing_rewriting("BartModel")
+        self.assertEqual(len(res), 2)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 38 additions & 0 deletions
@@ -3,6 +3,8 @@
 import torch
 from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


 class TestPatchPatchTorch(ExtTestCase):
@@ -236,6 +238,42 @@ def forward(self, x):
         ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
         self.assertEqualArray(Model()(x), ep.module()(x))

+    def test_oblivious_for_dimension_01(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind1, ind2):
+                return x[ind1, ind2]
+
+        inputs = (
+            torch.randn(2, 1024),
+            torch.tensor([[0, 1]], dtype=torch.int64).T,
+            torch.arange(1024, dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+
+        dynamic_string = ({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})
+        dynamic_shapes = use_dyn_not_str(dynamic_string)
+        with self.subTest(name="export 0/1 specialized due to hint of 1 for dimension"):
+            try:
+                torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
+                raise AssertionError("torch fixed that case")
+            except ValueError as e:
+                self.assertIn("export 0/1 specialized due to hint of 1 for dimension", str(e))
+
+        with self.subTest(name="expected shape should be broadcastable to"):
+            try:
+                with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+                    torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
+                raise AssertionError("torch fixed that case")
+            except RuntimeError as e:
+                self.assertIn("expected shape should be broadcastable to", str(e))
+
+        with self.subTest(name="patch for 0/1"):
+            with torch_export_patches():
+                ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
+            got = ep.module()(*inputs)
+            self.assertEqualArray(expected, got)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
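
A note on the helper used above: judging by its name and how the test calls it, use_dyn_not_str appears to turn the string placeholders ("A", "B", ...) into torch.export dynamic dimensions; that reading is an assumption, not documented in this diff. Under that assumption, a hand-written spec along these lines should be roughly equivalent:

import torch

# Assumption: roughly equivalent to
# use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"}))
DYN = torch.export.Dim.DYNAMIC
dynamic_shapes = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {0: DYN})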

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 12 additions & 0 deletions
@@ -341,6 +341,7 @@ def torch_export_patches(
             patched_infer_size,
             patched_vmap,
             patched__broadcast_shapes,
+            patched__constrain_user_specified_dimhint_range,
             _catch_produce_guards_and_solve_constraints,
             patch__check_input_constraints_for_graph,
         )
@@ -371,6 +372,14 @@ def torch_export_patches(
         torch._refs._broadcast_shapes = patched__broadcast_shapes
         torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes

+        # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+        f___constrain_user_specified_dimhint_range = (
+            torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+        )
+        torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+            patched__constrain_user_specified_dimhint_range
+        )
+
     # torch._export.non_strict_utils.produce_guards_and_solve_constraints
     if patch_torch and catch_constraints:
         if verbose:
@@ -569,6 +578,9 @@ def torch_export_patches(
         torch._subclasses.fake_impls.infer_size = f_infer_size
         torch._refs._broadcast_shapes = f__broadcast_shapes
         torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
+        torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+            f___constrain_user_specified_dimhint_range
+        )

         if verbose:
             print("[torch_export_patches] restored pytorch functions")

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 116 additions & 10 deletions
@@ -1,7 +1,7 @@
 import inspect
 import os
 import traceback
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode

@@ -65,6 +65,8 @@ def patch__check_input_constraints_for_graph(
     verbose: int = 0,
 ) -> None:
     try:
+        # PATCHED: catches the exception and prints out the information instead of
+        # stopping the conversion.
         return previous_function(input_placeholders, flat_args_with_path, range_constraints)
     except Exception as e:
         if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
@@ -122,8 +124,7 @@ def patched_infer_size(a, b):
         if b1 or b2 or b3:
             expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
         else:
-            # In this case, the current implementation of torch fails (17/12/2024).
-            # Try model SmolLM.
+            # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
     return tuple(expandedSizes)

@@ -132,7 +133,11 @@ def patched__broadcast_shapes(*_shapes):
     """Patches ``torch._refs._broadcast_shapes``."""
     from functools import reduce
     from torch._prims_common import IntLike
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        guard_or_false,
+        is_nested_int,
+    )

     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
@@ -142,17 +147,30 @@ def patched__broadcast_shapes(*_shapes):
     if len(shapes) == 0:
         return None

-    # Type checking
-    # TODO: make common validations available as utils
     for shape in shapes:
-        assert isinstance(shape, Sequence)
+        if not isinstance(shape, Sequence):
+            raise RuntimeError(
+                "Input shapes should be of type ints, a tuple of ints, "
+                "or a list of ints, got ",
+                shape,
+            )

     # Computes common shape
-    common_shape = [  # List[Union[int, torch.SymInt]]
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
+    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
     for _arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
+            if is_nested_int(shape[idx]):
+                # Broadcasting is allowed for (j0, 1) or (j0, j0);
+                # not (j0, j1), (j0, 5), etc.
+                if is_nested_int(common_shape[idx]) and guard_or_false(
+                    shape[idx] == common_shape[idx]
+                ):
+                    continue
+            else:
+                if guard_or_false(shape[idx] == common_shape[idx]):
+                    continue
+            # PATCHED: two cases, if equal for sure, no broadcast,
+            # otherwise maybe broadcast with max(dimensions)
             if guard_size_oblivious(common_shape[idx] == 1):
                 if shape[idx] < 0:
                     raise ValueError(
@@ -172,6 +190,7 @@ def _check_frozen(
     ) -> None:
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
+            # PATCHED: raises an exception instead of logging.
            raise AssertionError(
                f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
                f"this could result in accuracy problems"
@@ -338,11 +357,13 @@ def _set_replacement(
             },
         )

+        # PATCHED: removed lines
         # if config.print_specializations:
         #     self.log.warning(
         #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
         #     )
         #     self.log.debug("SPECIALIZATION", stack_info=True)
+        # PATCHED: replaces logging by raising an exception
        assert msg != "range_refined_to_singleton", (
            f"patched_ShapeEnv: A dynamic dimension becomes static! "
            f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
@@ -364,6 +385,7 @@ def _log_guard(
         self, prefix: str, g: "SympyBoolean", forcing_spec: bool  # noqa: F821
     ) -> None:
         self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
+        # PATCHED: removed
         # It happens too often to be relevant.
         # sloc, _maybe_extra_debug = self._get_stack_summary(True)
         # warnings.warn(
@@ -464,3 +486,87 @@ def wrapped(*args):
         return results

     return wrapped
+
+
+def patched__constrain_user_specified_dimhint_range(
+    symint: torch.SymInt,
+    hint: int,
+    dim: "_DimHint",  # noqa: F821
+    range_constraints,
+    shape_env,
+    keypath: "KeyPath",  # noqa: F821
+    i: Optional[int] = None,
+) -> Optional[str]:
+    """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
+    from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
+
+    trace_vr = (
+        range_constraints[symint.node.expr]
+        if not is_int(symint)
+        else ValueRanges(int(symint), int(symint))
+    )
+    # warn on 0/1 specialization for Dim.AUTO; not an actual error
+    # PATCHED: remove logging
+    # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
+    #     pathstr = f"inputs{pytree.keystr(keypath)}"
+    #     if i is not None:
+    #         pathstr += f".shape[{i}]"
+    #     msg = (
+    #         f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
+    #         f"with a sample input with hint = {hint}."
+    #     )
+    #     log.warning(msg)
+
+    try:
+        user_vr = ValueRanges(
+            lower=0 if dim.min is None else dim.min,
+            upper=int_oo if dim.max is None else dim.max,
+        )
+        if is_int(symint):
+            out_vr = trace_vr & user_vr
+        else:
+            range_constraints[symint.node.expr] &= user_vr
+            shape_env.var_to_range[symint.node._expr] &= user_vr
+            out_vr = range_constraints[symint.node.expr]
+
+        # check for Dim.DYNAMIC specializations; special case error message on 0/1
+        if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
+            path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+            if i is not None:
+                path += f".shape[{i}]"
+            if (
+                trace_vr.is_singleton()
+                and hint in (0, 1)
+                # PATCHED: line removed
+                # and not torch.fx.experimental._config.backed_size_oblivious
+            ):
+                return None
+                # PATCHED: line removed
+                # msg = (
+                #     f"- Received user-specified dim hint "
+                #     f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                #     f"but export 0/1 specialized due to hint of "
+                #     f"{hint} for dimension {path}."
+                # )
+            else:
+                msg = (
+                    f"- Received user-specified dim hint "
+                    f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                    f"but tracing inferred a static shape of "
+                    f"{out_vr.lower} for dimension {path}."
+                )
+            return msg
+
+    except torch.utils._sympy.value_ranges.ValueRangeError:
+        path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+        if i is not None:
+            path += f".shape[{i}]"
+        msg = (
+            f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
+            f"conflicting with the inferred min/max range of "
+            f"[{trace_vr.lower}, {trace_vr.upper}], "
+            f"for {path}."
+        )
+        return msg
+
+    return None
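
To summarize the behavioural change in patched__constrain_user_specified_dimhint_range: for a Dim.DYNAMIC dimension whose traced range collapses to a single value, the stock function reports an error when the sample input's hint was 0 or 1, while the patch returns None so the export can keep that dimension dynamic. A condensed schematic of that decision (simplified, not the library's code):

from typing import Optional


def dynamic_dim_message(hint: int, traced_singleton: bool, static_value: int) -> Optional[str]:
    # Schematic only: mirrors the branch structure of the patched function above.
    if traced_singleton and hint in (0, 1):
        # PATCHED behaviour: a 0/1 hint no longer produces an error message.
        return None
    return (
        f"- Received user-specified dim hint Dim.DYNAMIC, but tracing inferred "
        f"a static shape of {static_value} for this dimension."
    )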
