 import inspect
 import os
 import traceback
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode

@@ -65,6 +65,8 @@ def patch__check_input_constraints_for_graph(
     verbose: int = 0,
 ) -> None:
     try:
+        # PATCHED: catches the exception and prints the information instead of
+        # stopping the conversion.
         return previous_function(input_placeholders, flat_args_with_path, range_constraints)
     except Exception as e:
         if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
@@ -122,8 +124,7 @@ def patched_infer_size(a, b):
         if b1 or b2 or b3:
             expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
         else:
-            # In this case, the current implementation of torch fails (17/12/2024).
-            # Try model SmolLM.
+            # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
     return tuple(expandedSizes)

@@ -132,7 +133,11 @@ def patched__broadcast_shapes(*_shapes):
132133 """Patches ``torch._refs._broadcast_shapes``."""
133134 from functools import reduce
134135 from torch ._prims_common import IntLike
135- from torch .fx .experimental .symbolic_shapes import guard_size_oblivious
136+ from torch .fx .experimental .symbolic_shapes import (
137+ guard_size_oblivious ,
138+ guard_or_false ,
139+ is_nested_int ,
140+ )
136141
137142 shapes = tuple (
138143 (x ,) if isinstance (x , IntLike ) else x for x in filter (lambda x : x is not None , _shapes )
@@ -142,17 +147,30 @@ def patched__broadcast_shapes(*_shapes):
     if len(shapes) == 0:
         return None

-    # Type checking
-    # TODO: make common validations available as utils
     for shape in shapes:
-        assert isinstance(shape, Sequence)
+        if not isinstance(shape, Sequence):
+            raise RuntimeError(
+                "Input shapes should be of type ints, a tuple of ints, "
+                "or a list of ints, got ",
+                shape,
+            )

     # Computes common shape
-    common_shape = [  # List[Union[int, torch.SymInt]]
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
+    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
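+    # Start from an all-1s shape as long as the longest input, then widen it
+    # dimension by dimension against every input shape below.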
     for _arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
+            if is_nested_int(shape[idx]):
+                # Broadcasting is allowed for (j0, 1) or (j0, j0);
+                # not (j0, j1), (j0, 5), etc.
+                if is_nested_int(common_shape[idx]) and guard_or_false(
+                    shape[idx] == common_shape[idx]
+                ):
+                    continue
+            else:
+                if guard_or_false(shape[idx] == common_shape[idx]):
+                    continue
+            # PATCHED: two cases: if the dimensions are equal for sure, no broadcast;
+            # otherwise maybe broadcast with max(dimensions).
             if guard_size_oblivious(common_shape[idx] == 1):
                 if shape[idx] < 0:
                     raise ValueError(
@@ -172,6 +190,7 @@ def _check_frozen(
     ) -> None:
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
+            # PATCHED: raises an exception instead of logging.
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
                 f"this could result in accuracy problems"
@@ -338,11 +357,13 @@ def _set_replacement(
             },
         )

+        # PATCHED: removed lines
         # if config.print_specializations:
         #     self.log.warning(
         #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
         #     )
         #     self.log.debug("SPECIALIZATION", stack_info=True)
+        # PATCHED: replaces logging by raising an exception
         assert msg != "range_refined_to_singleton", (
             f"patched_ShapeEnv: A dynamic dimension becomes static! "
             f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
@@ -364,6 +385,7 @@ def _log_guard(
         self, prefix: str, g: "SympyBoolean", forcing_spec: bool  # noqa: F821
     ) -> None:
         self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
+        # PATCHED: removed
         # It happens too often to be relevant.
         # sloc, _maybe_extra_debug = self._get_stack_summary(True)
         # warnings.warn(
@@ -464,3 +486,87 @@ def wrapped(*args):
         return results

     return wrapped
+
+
+def patched__constrain_user_specified_dimhint_range(
+    symint: torch.SymInt,
+    hint: int,
+    dim: "_DimHint",  # noqa: F821
+    range_constraints,
+    shape_env,
+    keypath: "KeyPath",  # noqa: F821
+    i: Optional[int] = None,
+) -> Optional[str]:
+    """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
+    from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
+
+    trace_vr = (
+        range_constraints[symint.node.expr]
+        if not is_int(symint)
+        else ValueRanges(int(symint), int(symint))
+    )
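+    # trace_vr: the value range inferred during tracing
+    # (a singleton range when the dimension is a concrete int).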
+    # warn on 0/1 specialization for Dim.AUTO; not an actual error
+    # PATCHED: remove logging
+    # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
+    #     pathstr = f"inputs{pytree.keystr(keypath)}"
+    #     if i is not None:
+    #         pathstr += f".shape[{i}]"
+    #     msg = (
+    #         f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
+    #         f"with a sample input with hint = {hint}."
+    #     )
+    #     log.warning(msg)
+
+    try:
+        user_vr = ValueRanges(
+            lower=0 if dim.min is None else dim.min,
+            upper=int_oo if dim.max is None else dim.max,
+        )
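+        # Intersect the user-specified range with the range inferred by tracing.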
+        if is_int(symint):
+            out_vr = trace_vr & user_vr
+        else:
+            range_constraints[symint.node.expr] &= user_vr
+            shape_env.var_to_range[symint.node._expr] &= user_vr
+            out_vr = range_constraints[symint.node.expr]
+
+        # check for Dim.DYNAMIC specializations; special case error message on 0/1
+        if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
+            path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+            if i is not None:
+                path += f".shape[{i}]"
+            if (
+                trace_vr.is_singleton()
+                and hint in (0, 1)
+                # PATCHED: line removed
+                # and not torch.fx.experimental._config.backed_size_oblivious
+            ):
+                return None
+                # PATCHED: lines removed
+                # msg = (
+                #     f"- Received user-specified dim hint "
+                #     f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                #     f"but export 0/1 specialized due to hint of "
+                #     f"{hint} for dimension {path}."
+                # )
+            else:
+                msg = (
+                    f"- Received user-specified dim hint "
+                    f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                    f"but tracing inferred a static shape of "
+                    f"{out_vr.lower} for dimension {path}."
+                )
+            return msg
+
+    except torch.utils._sympy.value_ranges.ValueRangeError:
+        path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+        if i is not None:
+            path += f".shape[{i}]"
+        msg = (
+            f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
+            f"conflicting with the inferred min/max range of "
+            f"[{trace_vr.lower}, {trace_vr.upper}], "
+            f"for {path}."
+        )
+        return msg
+
+    return None