11# mypy: allow-untyped-defs
2+
3+ from __future__ import annotations
4+
5+
26"""
37This file does three things:
48- Contains the definition of SymNode
@@ -145,12 +149,12 @@ def compute_hint():
145149 )
146150 self .fx_node = tx_validation_en and fx_node
147151
148- def with_shape_env (self , shape_env : " ShapeEnv" ) -> " SymNode" :
152+ def with_shape_env (self , shape_env : ShapeEnv ) -> SymNode :
149153 return SymNode (
150154 self ._expr , shape_env , self .pytype , self ._hint , self .constant , self .fx_node
151155 )
152156
153- def _value_eq (self , other : " SymNode" ) -> bool :
157+ def _value_eq (self , other : SymNode ) -> bool :
154158 # Purposely don't include the shape_env in the eq.
155159 return (
156160 self ._expr == other ._expr
@@ -281,121 +285,121 @@ def _graph_repr(self) -> builtins.str:
281285
282286 # These methods call the metaprogrammed methods, they're hand written
283287 # here so we get good stack traces
284- def abs (self ) -> " SymNode" :
288+ def abs (self ) -> SymNode :
285289 return self ._abs () # type: ignore[attr-defined]
286290
287- def pos (self ) -> " SymNode" :
291+ def pos (self ) -> SymNode :
288292 return self ._pos () # type: ignore[attr-defined]
289293
290- def round (self , ndigits = None ) -> " SymNode" :
294+ def round (self , ndigits = None ) -> SymNode :
291295 return self ._round (ndigits ) # type: ignore[attr-defined]
292296
293- def trunc (self ) -> " SymNode" :
297+ def trunc (self ) -> SymNode :
294298 return self ._trunc () # type: ignore[attr-defined]
295299
296- def add (self , other ) -> " SymNode" :
300+ def add (self , other ) -> SymNode :
297301 return self ._add (other ) # type: ignore[attr-defined]
298302
299- def sub (self , other ) -> " SymNode" :
303+ def sub (self , other ) -> SymNode :
300304 return self ._sub (other ) # type: ignore[attr-defined]
301305
302- def mul (self , other ) -> " SymNode" :
306+ def mul (self , other ) -> SymNode :
303307 return self ._mul (other ) # type: ignore[attr-defined]
304308
305- def mod (self , other ) -> " SymNode" :
309+ def mod (self , other ) -> SymNode :
306310 return self ._mod (other ) # type: ignore[attr-defined]
307311
308- def float_pow (self , other ) -> " SymNode" :
312+ def float_pow (self , other ) -> SymNode :
309313 return self ._float_pow (other ) # type: ignore[attr-defined]
310314
311- def pow_by_natural (self , other ) -> " SymNode" :
315+ def pow_by_natural (self , other ) -> SymNode :
312316 return self ._pow_by_natural (other ) # type: ignore[attr-defined]
313317
314- def and_ (self , other ) -> " SymNode" :
318+ def and_ (self , other ) -> SymNode :
315319 return self ._and_ (other ) # type: ignore[attr-defined]
316320
317- def or_ (self , other ) -> " SymNode" :
321+ def or_ (self , other ) -> SymNode :
318322 return self ._or_ (other ) # type: ignore[attr-defined]
319323
320- def float_truediv (self , other ) -> " SymNode" :
324+ def float_truediv (self , other ) -> SymNode :
321325 return self ._float_truediv (other ) # type: ignore[attr-defined]
322326
323- def int_truediv (self , other ) -> " SymNode" :
327+ def int_truediv (self , other ) -> SymNode :
324328 return self ._int_truediv (other ) # type: ignore[attr-defined]
325329
326- def int_floordiv (self , other ) -> " SymNode" :
330+ def int_floordiv (self , other ) -> SymNode :
327331 return self ._int_floordiv (other ) # type: ignore[attr-defined]
328332
329- def lshift (self , other ) -> " SymNode" :
333+ def lshift (self , other ) -> SymNode :
330334 return self ._lshift (other ) # type: ignore[attr-defined]
331335
332- def rshift (self , other ) -> " SymNode" :
336+ def rshift (self , other ) -> SymNode :
333337 return self ._rshift (other ) # type: ignore[attr-defined]
334338
335- def sym_not (self ) -> " SymNode" : # noqa: F811
339+ def sym_not (self ) -> SymNode : # noqa: F811
336340 return self ._sym_not () # type: ignore[attr-defined]
337341
338- def eq (self , other ) -> " SymNode" :
342+ def eq (self , other ) -> SymNode :
339343 return self ._eq (other ) # type: ignore[attr-defined]
340344
341- def ne (self , other ) -> " SymNode" :
345+ def ne (self , other ) -> SymNode :
342346 return self ._ne (other ) # type: ignore[attr-defined]
343347
344- def gt (self , other ) -> " SymNode" :
348+ def gt (self , other ) -> SymNode :
345349 return self ._gt (other ) # type: ignore[attr-defined]
346350
347- def lt (self , other ) -> " SymNode" :
351+ def lt (self , other ) -> SymNode :
348352 return self ._lt (other ) # type: ignore[attr-defined]
349353
350- def le (self , other ) -> " SymNode" :
354+ def le (self , other ) -> SymNode :
351355 return self ._le (other ) # type: ignore[attr-defined]
352356
353- def ge (self , other ) -> " SymNode" :
357+ def ge (self , other ) -> SymNode :
354358 return self ._ge (other ) # type: ignore[attr-defined]
355359
356- def floor (self ) -> " SymNode" :
360+ def floor (self ) -> SymNode :
357361 return self ._floor () # type: ignore[attr-defined]
358362
359- def is_integer (self ) -> " SymNode" :
363+ def is_integer (self ) -> SymNode :
360364 return self ._is_integer () # type: ignore[attr-defined]
361365
362- def sym_float (self ) -> " SymNode" : # noqa: F811
366+ def sym_float (self ) -> SymNode : # noqa: F811
363367 return self ._sym_float () # type: ignore[attr-defined]
364368
365- def sym_int (self ) -> " SymNode" :
369+ def sym_int (self ) -> SymNode :
366370 return self ._sym_int () # type: ignore[attr-defined]
367371
368- def ceil (self ) -> " SymNode" :
372+ def ceil (self ) -> SymNode :
369373 return self ._ceil () # type: ignore[attr-defined]
370374
371- def neg (self ) -> " SymNode" :
375+ def neg (self ) -> SymNode :
372376 return self ._neg () # type: ignore[attr-defined]
373377
374- def sym_min (self , other ) -> " SymNode" : # noqa: F811
378+ def sym_min (self , other ) -> SymNode : # noqa: F811
375379 return self ._sym_min (other ) # type: ignore[attr-defined]
376380
377- def sym_max (self , other ) -> " SymNode" : # noqa: F811
381+ def sym_max (self , other ) -> SymNode : # noqa: F811
378382 return self ._sym_max (other ) # type: ignore[attr-defined]
379383
380- def sym_ite (self , then_val , else_val ) -> " SymNode" :
384+ def sym_ite (self , then_val , else_val ) -> SymNode :
381385 return self ._sym_ite (then_val , else_val ) # type: ignore[attr-defined]
382386
383- def is_contiguous (self , sizes , strides ) -> " SymNode" :
387+ def is_contiguous (self , sizes , strides ) -> SymNode :
384388 return self ._is_contiguous (sizes , strides ) # type: ignore[attr-defined]
385389
386- def is_channels_last_contiguous_2d (self , sizes , strides ) -> " SymNode" :
390+ def is_channels_last_contiguous_2d (self , sizes , strides ) -> SymNode :
387391 return self ._is_channels_last_contiguous_2d (sizes , strides ) # type: ignore[attr-defined]
388392
389- def is_channels_last_contiguous_3d (self , sizes , strides ) -> " SymNode" :
393+ def is_channels_last_contiguous_3d (self , sizes , strides ) -> SymNode :
390394 return self ._is_channels_last_contiguous_3d (sizes , strides ) # type: ignore[attr-defined]
391395
392- def is_channels_last_strides_2d (self , sizes , strides ) -> " SymNode" :
396+ def is_channels_last_strides_2d (self , sizes , strides ) -> SymNode :
393397 return self ._is_channels_last_strides_2d (sizes , strides ) # type: ignore[attr-defined]
394398
395- def is_channels_last_strides_3d (self , sizes , strides ) -> " SymNode" :
399+ def is_channels_last_strides_3d (self , sizes , strides ) -> SymNode :
396400 return self ._is_channels_last_strides_3d (sizes , strides ) # type: ignore[attr-defined]
397401
398- def is_non_overlapping_and_dense_indicator (self , sizes , strides ) -> " SymNode" :
402+ def is_non_overlapping_and_dense_indicator (self , sizes , strides ) -> SymNode :
399403 return self ._is_non_overlapping_and_dense_indicator (sizes , strides ) # type: ignore[attr-defined]
400404
401405 # Make C++ happy
@@ -409,7 +413,7 @@ def sym_and(self, other):
409413 def truediv (self , other ):
410414 return self .float_truediv (other )
411415
412- def floordiv (self , other ) -> " SymNode" :
416+ def floordiv (self , other ) -> SymNode :
413417 return self .int_floordiv (other )
414418
415419 # We didn't bind integer pow in C++
@@ -426,7 +430,7 @@ def int_(self):
426430 # functions consider factoring it out to be metaprogrammed too. Note that
427431 # some load bearing logic is directly in torch.sym_sum
428432
429- def sym_sum (self , args ) -> " SymNode" :
433+ def sym_sum (self , args ) -> SymNode :
430434 import sympy
431435
432436 # Inner impl
@@ -629,6 +633,7 @@ def fn(self):
629633 "asin" ,
630634 "acos" ,
631635 "atan" ,
636+ "log2" ,
632637)
633638for name in math_op_names :
634639 sym_name = f"sym_{ name } "
@@ -656,7 +661,7 @@ def fn(self):
656661bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
657662
658663# Methods that are only for float
659- only_float_magic_methods = {"is_integer" , "round" , "sym_int" }
664+ only_float_magic_methods = {"is_integer" , "round" , "sym_int" , "sym_log2" }
660665
661666
662667magic_methods_on_operator_with_trailing_underscore = {"and" , "or" }
0 commit comments