11# mypy: allow-untyped-defs
2-
3- from __future__ import annotations
4-
5-
62"""
73This file does three things:
84- Contains the definition of SymNode
@@ -149,12 +145,12 @@ def compute_hint():
149145 )
150146 self .fx_node = tx_validation_en and fx_node
151147
152- def with_shape_env (self , shape_env : ShapeEnv ) -> SymNode :
148+ def with_shape_env (self , shape_env : " ShapeEnv" ) -> " SymNode" :
153149 return SymNode (
154150 self ._expr , shape_env , self .pytype , self ._hint , self .constant , self .fx_node
155151 )
156152
157- def _value_eq (self , other : SymNode ) -> bool :
153+ def _value_eq (self , other : " SymNode" ) -> bool :
158154 # Purposely don't include the shape_env in the eq.
159155 return (
160156 self ._expr == other ._expr
@@ -285,121 +281,121 @@ def _graph_repr(self) -> builtins.str:
285281
286282 # These methods call the metaprogrammed methods, they're hand written
287283 # here so we get good stack traces
288- def abs (self ) -> SymNode :
284+ def abs (self ) -> " SymNode" :
289285 return self ._abs () # type: ignore[attr-defined]
290286
291- def pos (self ) -> SymNode :
287+ def pos (self ) -> " SymNode" :
292288 return self ._pos () # type: ignore[attr-defined]
293289
294- def round (self , ndigits = None ) -> SymNode :
290+ def round (self , ndigits = None ) -> " SymNode" :
295291 return self ._round (ndigits ) # type: ignore[attr-defined]
296292
297- def trunc (self ) -> SymNode :
293+ def trunc (self ) -> " SymNode" :
298294 return self ._trunc () # type: ignore[attr-defined]
299295
300- def add (self , other ) -> SymNode :
296+ def add (self , other ) -> " SymNode" :
301297 return self ._add (other ) # type: ignore[attr-defined]
302298
303- def sub (self , other ) -> SymNode :
299+ def sub (self , other ) -> " SymNode" :
304300 return self ._sub (other ) # type: ignore[attr-defined]
305301
306- def mul (self , other ) -> SymNode :
302+ def mul (self , other ) -> " SymNode" :
307303 return self ._mul (other ) # type: ignore[attr-defined]
308304
309- def mod (self , other ) -> SymNode :
305+ def mod (self , other ) -> " SymNode" :
310306 return self ._mod (other ) # type: ignore[attr-defined]
311307
312- def float_pow (self , other ) -> SymNode :
308+ def float_pow (self , other ) -> " SymNode" :
313309 return self ._float_pow (other ) # type: ignore[attr-defined]
314310
315- def pow_by_natural (self , other ) -> SymNode :
311+ def pow_by_natural (self , other ) -> " SymNode" :
316312 return self ._pow_by_natural (other ) # type: ignore[attr-defined]
317313
318- def and_ (self , other ) -> SymNode :
314+ def and_ (self , other ) -> " SymNode" :
319315 return self ._and_ (other ) # type: ignore[attr-defined]
320316
321- def or_ (self , other ) -> SymNode :
317+ def or_ (self , other ) -> " SymNode" :
322318 return self ._or_ (other ) # type: ignore[attr-defined]
323319
324- def float_truediv (self , other ) -> SymNode :
320+ def float_truediv (self , other ) -> " SymNode" :
325321 return self ._float_truediv (other ) # type: ignore[attr-defined]
326322
327- def int_truediv (self , other ) -> SymNode :
323+ def int_truediv (self , other ) -> " SymNode" :
328324 return self ._int_truediv (other ) # type: ignore[attr-defined]
329325
330- def int_floordiv (self , other ) -> SymNode :
326+ def int_floordiv (self , other ) -> " SymNode" :
331327 return self ._int_floordiv (other ) # type: ignore[attr-defined]
332328
333- def lshift (self , other ) -> SymNode :
329+ def lshift (self , other ) -> " SymNode" :
334330 return self ._lshift (other ) # type: ignore[attr-defined]
335331
336- def rshift (self , other ) -> SymNode :
332+ def rshift (self , other ) -> " SymNode" :
337333 return self ._rshift (other ) # type: ignore[attr-defined]
338334
339- def sym_not (self ) -> SymNode : # noqa: F811
335+ def sym_not (self ) -> " SymNode" : # noqa: F811
340336 return self ._sym_not () # type: ignore[attr-defined]
341337
342- def eq (self , other ) -> SymNode :
338+ def eq (self , other ) -> " SymNode" :
343339 return self ._eq (other ) # type: ignore[attr-defined]
344340
345- def ne (self , other ) -> SymNode :
341+ def ne (self , other ) -> " SymNode" :
346342 return self ._ne (other ) # type: ignore[attr-defined]
347343
348- def gt (self , other ) -> SymNode :
344+ def gt (self , other ) -> " SymNode" :
349345 return self ._gt (other ) # type: ignore[attr-defined]
350346
351- def lt (self , other ) -> SymNode :
347+ def lt (self , other ) -> " SymNode" :
352348 return self ._lt (other ) # type: ignore[attr-defined]
353349
354- def le (self , other ) -> SymNode :
350+ def le (self , other ) -> " SymNode" :
355351 return self ._le (other ) # type: ignore[attr-defined]
356352
357- def ge (self , other ) -> SymNode :
353+ def ge (self , other ) -> " SymNode" :
358354 return self ._ge (other ) # type: ignore[attr-defined]
359355
360- def floor (self ) -> SymNode :
356+ def floor (self ) -> " SymNode" :
361357 return self ._floor () # type: ignore[attr-defined]
362358
363- def is_integer (self ) -> SymNode :
359+ def is_integer (self ) -> " SymNode" :
364360 return self ._is_integer () # type: ignore[attr-defined]
365361
366- def sym_float (self ) -> SymNode : # noqa: F811
362+ def sym_float (self ) -> " SymNode" : # noqa: F811
367363 return self ._sym_float () # type: ignore[attr-defined]
368364
369- def sym_int (self ) -> SymNode :
365+ def sym_int (self ) -> " SymNode" :
370366 return self ._sym_int () # type: ignore[attr-defined]
371367
372- def ceil (self ) -> SymNode :
368+ def ceil (self ) -> " SymNode" :
373369 return self ._ceil () # type: ignore[attr-defined]
374370
375- def neg (self ) -> SymNode :
371+ def neg (self ) -> " SymNode" :
376372 return self ._neg () # type: ignore[attr-defined]
377373
378- def sym_min (self , other ) -> SymNode : # noqa: F811
374+ def sym_min (self , other ) -> " SymNode" : # noqa: F811
379375 return self ._sym_min (other ) # type: ignore[attr-defined]
380376
381- def sym_max (self , other ) -> SymNode : # noqa: F811
377+ def sym_max (self , other ) -> " SymNode" : # noqa: F811
382378 return self ._sym_max (other ) # type: ignore[attr-defined]
383379
384- def sym_ite (self , then_val , else_val ) -> SymNode :
380+ def sym_ite (self , then_val , else_val ) -> " SymNode" :
385381 return self ._sym_ite (then_val , else_val ) # type: ignore[attr-defined]
386382
387- def is_contiguous (self , sizes , strides ) -> SymNode :
383+ def is_contiguous (self , sizes , strides ) -> " SymNode" :
388384 return self ._is_contiguous (sizes , strides ) # type: ignore[attr-defined]
389385
390- def is_channels_last_contiguous_2d (self , sizes , strides ) -> SymNode :
386+ def is_channels_last_contiguous_2d (self , sizes , strides ) -> " SymNode" :
391387 return self ._is_channels_last_contiguous_2d (sizes , strides ) # type: ignore[attr-defined]
392388
393- def is_channels_last_contiguous_3d (self , sizes , strides ) -> SymNode :
389+ def is_channels_last_contiguous_3d (self , sizes , strides ) -> " SymNode" :
394390 return self ._is_channels_last_contiguous_3d (sizes , strides ) # type: ignore[attr-defined]
395391
396- def is_channels_last_strides_2d (self , sizes , strides ) -> SymNode :
392+ def is_channels_last_strides_2d (self , sizes , strides ) -> " SymNode" :
397393 return self ._is_channels_last_strides_2d (sizes , strides ) # type: ignore[attr-defined]
398394
399- def is_channels_last_strides_3d (self , sizes , strides ) -> SymNode :
395+ def is_channels_last_strides_3d (self , sizes , strides ) -> " SymNode" :
400396 return self ._is_channels_last_strides_3d (sizes , strides ) # type: ignore[attr-defined]
401397
402- def is_non_overlapping_and_dense_indicator (self , sizes , strides ) -> SymNode :
398+ def is_non_overlapping_and_dense_indicator (self , sizes , strides ) -> " SymNode" :
403399 return self ._is_non_overlapping_and_dense_indicator (sizes , strides ) # type: ignore[attr-defined]
404400
405401 # Make C++ happy
@@ -413,7 +409,7 @@ def sym_and(self, other):
413409 def truediv (self , other ):
414410 return self .float_truediv (other )
415411
416- def floordiv (self , other ) -> SymNode :
412+ def floordiv (self , other ) -> " SymNode" :
417413 return self .int_floordiv (other )
418414
419415 # We didn't bind integer pow in C++
@@ -633,7 +629,6 @@ def fn(self):
633629 "asin" ,
634630 "acos" ,
635631 "atan" ,
636- "log2" ,
637632)
638633for name in math_op_names :
639634 sym_name = f"sym_{ name } "
@@ -661,7 +656,7 @@ def fn(self):
661656bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
662657
663658# Methods that are only for float
664- only_float_magic_methods = {"is_integer" , "round" , "sym_int" , "sym_log2" }
659+ only_float_magic_methods = {"is_integer" , "round" , "sym_int" }
665660
666661
667662magic_methods_on_operator_with_trailing_underscore = {"and" , "or" }
0 commit comments