123123 builtin_dict_keys ,
124124 common_constant_types ,
125125 dict_keys ,
126- dict_keys_repr ,
127126 get_custom_getattr ,
128127 get_torch_function_mode_stack ,
129128 get_torch_function_mode_stack_at ,
@@ -422,7 +421,7 @@ def _get_closure_vars():
422421 "___odict_getitem" : collections .OrderedDict .__getitem__ ,
423422 "___key_to_id" : key_to_id ,
424423 "___dict_version" : dict_version ,
425- "___dict_contains" : lambda a , b : a in b ,
424+ "___dict_contains" : lambda a , b : dict . __contains__ ( b , a ) ,
426425 "___tuple_iterator_len" : tuple_iterator_len ,
427426 "___normalize_range_iter" : normalize_range_iter ,
428427 "___tuple_iterator_getitem" : tuple_iterator_getitem ,
@@ -1732,29 +1731,6 @@ def DUPLICATE_INPUT(self, guard, source_b):
17321731 get_verbose_code_parts (code , guard ),
17331732 )
17341733
1735- def DICT_KEYS (self , guard ):
1736- # Guard on the keys and their order
1737- ref = self .arg_ref (guard )
1738- value = self .get (guard .name )
1739-
1740- self .TYPE_MATCH (guard )
1741- code = []
1742- any_key_is_id = any (key_is_id (k ) for k in builtin_dict_keys (value ))
1743- const_keys_repr = dict_keys_repr (
1744- key_to_id (value ),
1745- local = is_from_local_source (guard .originating_source ),
1746- )
1747- if any_key_is_id :
1748- code .append (f"___key_to_id({ ref } ) == { const_keys_repr } " )
1749- else :
1750- code .append (f"list({ ref } .keys()) == { const_keys_repr } " )
1751-
1752- self ._set_guard_export_info (guard , code )
1753- if self .requires_key_order_guarding (guard .originating_source ):
1754- self .guard_on_dict_keys_and_order (value , guard )
1755- else :
1756- self .guard_on_dict_keys_and_ignore_order (value , guard )
1757-
17581734 def WEAKREF_ALIVE (self , guard ):
17591735 code = [f"{ self .arg_ref (guard )} is not None" ]
17601736
@@ -1763,11 +1739,18 @@ def WEAKREF_ALIVE(self, guard):
17631739 get_verbose_code_parts (code , guard )
17641740 )
17651741
1766- def DICT_CONST_KEYS (self , guard ):
1767- """Constant keys match """
1742+ def DICT_KEYS_MATCH (self , guard ):
1743+ """Insert guard to check that the keys of a dict are same """
17681744 ref = self .arg_ref (guard )
17691745 value = self .get (guard .name )
17701746
1747+ if value is torch .utils ._pytree .SUPPORTED_NODES :
1748+ # For SUPPORTED_NODES, we can guard on the dictionary version (PEP509).
1749+ self .DICT_VERSION (guard )
1750+ return
1751+
1752+ self .SEQUENCE_LENGTH (guard )
1753+
17711754 code = []
17721755 # Ensure that we call dict.keys and not value.keys (which can call
17731756 # overridden keys method). In the C++ guards, we relied on PyDict_Next
0 commit comments