@@ -146,3 +146,183 @@ def patched__broadcast_shapes(*_shapes):
146146 common_shape [idx ] = torch .sym_max (common_shape [idx ], shape [idx ])
147147
148148 return common_shape
149+
150+
151+ class patched_ShapeEnv :
152+
153+ def _set_replacement (
154+ self , a : "sympy.Symbol" , tgt : "sympy.Expr" , msg : str # noqa: F821
155+ ) -> None :
156+ """
157+ Adds or updates a replacement for a symbol.
158+ Use this instead of `self.replacements[a] = tgt`.
159+ """
160+ if tgt == self .replacements .get (a , None ):
161+ return
162+
163+ if a in tgt .free_symbols :
164+ return
165+
166+ import sympy
167+ from torch ._logging import structured
168+ from torch .utils ._traceback import CapturedTraceback
169+ from torch ._logging import trace_structured
170+ from torch ._guards import TracingContext
171+ from torch .utils ._sympy .functions import FloorToInt , CeilToInt
172+ from torch .utils ._sympy .solve import try_solve
173+ from torch .fx .experimental .symbolic_shapes import (
174+ _is_supported_equivalence ,
175+ ValueRanges ,
176+ )
177+
178+ # Precondition: a == tgt
179+ assert isinstance (a , sympy .Symbol )
180+
181+ if self .allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence (tgt ):
182+ # continuing leads to placeholder shapes
183+ # having complex expressions that we can't resolve
184+ return
185+
186+ # Handles nested tensor symbolic variables which don't have
187+ # var_to_range bounds
188+ tgt_bound = None
189+ if a in self .var_to_range :
190+ src_bound = self .var_to_range [a ]
191+
192+ # First, refine the value range of a based on the computed value range
193+ # of tgt. This is always OK to do, even if we decide not to do the
194+ # substitution in the end. This might be a no-op, if a already has
195+ # a tighter bound
196+ tgt_bound = self .bound_sympy (tgt )
197+ self ._update_var_to_range (a , tgt_bound )
198+
199+ # Next, check if we can update the range of free symbols in tgt
200+ # based on the range in a. But only do it if:
201+ # - the source bound non-trivially improves over what we get out of
202+ # the existing bounds.
203+ # - the replacement is univariate and we can invert the tgt expression
204+ if not tgt_bound .issubset (src_bound ) and len (tgt .free_symbols ) == 1 :
205+ b = next (iter (tgt .free_symbols ))
206+ # Try to invert the equality
207+ r = try_solve (sympy .Eq (a , tgt ), b , floordiv_inequality = False )
208+ if r is not None :
209+ self .log .debug (
210+ "set_replacement: solve for %s in %s == %s gives %s" ,
211+ b ,
212+ a ,
213+ tgt ,
214+ r ,
215+ )
216+ # The solution here can be non-integral, for example, if
217+ # we have s0 = 2*s1, then s1 = s0/2. What we would like
218+ # to do is calculated the bounds in arbitrary precision,
219+ # and then requantize the bound to integers when we are
220+ # done.
221+ rat_b_bound = self .bound_sympy (r [1 ])
222+ b_bound = ValueRanges (
223+ CeilToInt (rat_b_bound .lower ), FloorToInt (rat_b_bound .upper )
224+ )
225+ self ._update_var_to_range (b , b_bound , self .var_to_range_sloc [a ])
226+ tgt_bound = self .bound_sympy (tgt )
227+ assert tgt_bound .issubset (
228+ src_bound
229+ ), f"{ tgt_bound = } not a subset of { src_bound = } "
230+
231+ # TODO: Should we propagate size-like-ness?
232+ #
233+ # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
234+ # to become size-like.
235+ #
236+ # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T
237+ # propagate in this case, because what if u0 == 0, then u1 is negative
238+ # and clearly isn't a size. So, at minimum, any f(x) whose value
239+ # range isn't [0, inf] given x in [0, inf] cannot propagate
240+ # size-like-ness. But there are many situations where you could
241+ # imagine u1 is going to be size-like and actually you just didn't
242+ # have a refined enough value range on u0. Since even innocuous
243+ # looking arithmetic operations can destroy size-like-ness, it's
244+ # best to not propagate it at all and force the user to annotate it
245+ # as necessary.
246+ #
247+ # Compromise: we preserve size-like-ness only for exact equality
248+ # and nothing else.
249+ if a in self .size_like and isinstance (tgt , sympy .Symbol ):
250+ self .size_like .add (tgt )
251+ elif isinstance (tgt , sympy .Symbol ) and tgt in self .size_like :
252+ self .size_like .add (a )
253+
254+ # Now, decide if we will do the substitution.
255+ #
256+ # - If the source has a non-trivial range, only substitute if
257+ # we preserve this range. Note that we may have propagated
258+ # the src_range to free variables in tgt when tgt is univariate
259+ # and we could find an inverse, which helps us achieve this.
260+ # This ensures we never "forget" about user defined ranges,
261+ # even if they end up being defined on composite formulas
262+ # like s0 + s1.
263+ #
264+ # - If the variable is unbacked, only substitute if the substitution
265+ # would preserve the bounds also under size-like-ness conditions.
266+
267+ if not tgt_bound .issubset (src_bound ):
268+ self .log .debug (
269+ "skipped set_replacement %s = %s (%s) [%s not subset of %s]" ,
270+ a ,
271+ tgt ,
272+ msg ,
273+ tgt_bound ,
274+ src_bound ,
275+ )
276+ return
277+ elif a in self .size_like :
278+ tgt_bound_so = self .bound_sympy (tgt , size_oblivious = True )
279+ src_bound_so = self .bound_sympy (a , size_oblivious = True )
280+ if not tgt_bound_so .issubset (src_bound_so ):
281+ self .log .debug (
282+ "skipped set_replacement %s = %s (%s) "
283+ "[%s not subset of %s (size-oblivious conditions)]" ,
284+ a ,
285+ tgt ,
286+ msg ,
287+ tgt_bound_so ,
288+ src_bound_so ,
289+ )
290+ return
291+
292+ if isinstance (tgt , (sympy .Integer , sympy .Float )):
293+ # specializing to a constant, which is likely unexpected (unless
294+ # you specified dynamic=True)
295+
296+ user_tb = TracingContext .extract_stack ()
297+ trace_structured (
298+ "symbolic_shape_specialization" ,
299+ metadata_fn = lambda : {
300+ "symbol" : repr (a ),
301+ "sources" : [s .name () for s in self .var_to_sources .get (a , [])],
302+ "value" : repr (tgt ),
303+ "reason" : msg ,
304+ "stack" : structured .from_traceback (
305+ CapturedTraceback .extract (skip = 1 ).summary ()
306+ ),
307+ "user_stack" : (structured .from_traceback (user_tb ) if user_tb else None ),
308+ },
309+ )
310+
311+ # if config.print_specializations:
312+ # self.log.warning(
313+ # "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
314+ # )
315+ # self.log.debug("SPECIALIZATION", stack_info=True)
316+ assert msg != "range_refined_to_singleton" , f"{ [a , tgt , msg , tgt_bound ]} "
317+ # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
318+ self .replacements [a ] = tgt
319+ # NB: the replacement may get refined, but the user will find the
320+ # FIRST one most useful (TODO: Maybe we could consider tracking all of
321+ # them)
322+ if a not in self .replacements_slocs :
323+ self .replacements_slocs [a ] = self ._get_sloc ()
324+ self ._update_version_counter ()
325+
326+ # When specializing 'a == tgt', the equality should be also conveyed to
327+ # Z3, in case an expression uses 'a'.
328+ self ._add_target_expr (sympy .Eq (a , tgt , evaluate = False ))
0 commit comments