66import re
77import subprocess
88import sys
9+ import time
910import unittest
1011from collections import defaultdict , deque , namedtuple , OrderedDict , UserDict
1112from dataclasses import dataclass
@@ -731,6 +732,133 @@ def test_pytree_serialize_bad_input(self, pytree_impl):
731732 with self .assertRaises (TypeError ):
732733 pytree_impl .treespec_dumps ("random_blurb" )
733734
735+ @parametrize (
736+ "pytree" ,
737+ [
738+ subtest (py_pytree , name = "py" ),
739+ subtest (cxx_pytree , name = "cxx" ),
740+ ],
741+ )
742+ def test_is_namedtuple (self , pytree ):
743+ DirectNamedTuple1 = namedtuple ("DirectNamedTuple1" , ["x" , "y" ])
744+
745+ class DirectNamedTuple2 (NamedTuple ):
746+ x : int
747+ y : int
748+
749+ class IndirectNamedTuple1 (DirectNamedTuple1 ):
750+ pass
751+
752+ class IndirectNamedTuple2 (DirectNamedTuple2 ):
753+ pass
754+
755+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple1 (0 , 1 )))
756+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple2 (0 , 1 )))
757+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple1 (0 , 1 )))
758+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple2 (0 , 1 )))
759+ self .assertFalse (pytree .is_namedtuple (time .gmtime ()))
760+ self .assertFalse (pytree .is_namedtuple ((0 , 1 )))
761+ self .assertFalse (pytree .is_namedtuple ([0 , 1 ]))
762+ self .assertFalse (pytree .is_namedtuple ({0 : 1 , 1 : 2 }))
763+ self .assertFalse (pytree .is_namedtuple ({0 , 1 }))
764+ self .assertFalse (pytree .is_namedtuple (1 ))
765+
766+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple1 ))
767+ self .assertTrue (pytree .is_namedtuple (DirectNamedTuple2 ))
768+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple1 ))
769+ self .assertTrue (pytree .is_namedtuple (IndirectNamedTuple2 ))
770+ self .assertFalse (pytree .is_namedtuple (time .struct_time ))
771+ self .assertFalse (pytree .is_namedtuple (tuple ))
772+ self .assertFalse (pytree .is_namedtuple (list ))
773+
774+ self .assertTrue (pytree .is_namedtuple_class (DirectNamedTuple1 ))
775+ self .assertTrue (pytree .is_namedtuple_class (DirectNamedTuple2 ))
776+ self .assertTrue (pytree .is_namedtuple_class (IndirectNamedTuple1 ))
777+ self .assertTrue (pytree .is_namedtuple_class (IndirectNamedTuple2 ))
778+ self .assertFalse (pytree .is_namedtuple_class (time .struct_time ))
779+ self .assertFalse (pytree .is_namedtuple_class (tuple ))
780+ self .assertFalse (pytree .is_namedtuple_class (list ))
781+
782+ @parametrize (
783+ "pytree" ,
784+ [
785+ subtest (py_pytree , name = "py" ),
786+ subtest (cxx_pytree , name = "cxx" ),
787+ ],
788+ )
789+ def test_is_structseq (self , pytree ):
790+ class FakeStructSeq (tuple ):
791+ n_fields = 2
792+ n_sequence_fields = 2
793+ n_unnamed_fields = 0
794+
795+ __slots__ = ()
796+ __match_args__ = ("x" , "y" )
797+
798+ def __new__ (cls , sequence ):
799+ return super ().__new__ (cls , sequence )
800+
801+ @property
802+ def x (self ):
803+ return self [0 ]
804+
805+ @property
806+ def y (self ):
807+ return self [1 ]
808+
809+ DirectNamedTuple1 = namedtuple ("DirectNamedTuple1" , ["x" , "y" ])
810+
811+ class DirectNamedTuple2 (NamedTuple ):
812+ x : int
813+ y : int
814+
815+ self .assertFalse (pytree .is_structseq (FakeStructSeq ((0 , 1 ))))
816+ self .assertTrue (pytree .is_structseq (time .gmtime ()))
817+ self .assertFalse (pytree .is_structseq (DirectNamedTuple1 (0 , 1 )))
818+ self .assertFalse (pytree .is_structseq (DirectNamedTuple2 (0 , 1 )))
819+ self .assertFalse (pytree .is_structseq ((0 , 1 )))
820+ self .assertFalse (pytree .is_structseq ([0 , 1 ]))
821+ self .assertFalse (pytree .is_structseq ({0 : 1 , 1 : 2 }))
822+ self .assertFalse (pytree .is_structseq ({0 , 1 }))
823+ self .assertFalse (pytree .is_structseq (1 ))
824+
825+ self .assertFalse (pytree .is_structseq (FakeStructSeq ))
826+ self .assertTrue (pytree .is_structseq (time .struct_time ))
827+ self .assertFalse (pytree .is_structseq (DirectNamedTuple1 ))
828+ self .assertFalse (pytree .is_structseq (DirectNamedTuple2 ))
829+ self .assertFalse (pytree .is_structseq (tuple ))
830+ self .assertFalse (pytree .is_structseq (list ))
831+
832+ self .assertFalse (pytree .is_structseq_class (FakeStructSeq ))
833+ self .assertTrue (
834+ pytree .is_structseq_class (time .struct_time ),
835+ )
836+ self .assertFalse (pytree .is_structseq_class (DirectNamedTuple1 ))
837+ self .assertFalse (pytree .is_structseq_class (DirectNamedTuple2 ))
838+ self .assertFalse (pytree .is_structseq_class (tuple ))
839+ self .assertFalse (pytree .is_structseq_class (list ))
840+
841+ # torch.return_types.* are all PyStructSequence types
842+ for cls in vars (torch .return_types ).values ():
843+ if isinstance (cls , type ) and issubclass (cls , tuple ):
844+ self .assertTrue (pytree .is_structseq (cls ))
845+ self .assertTrue (pytree .is_structseq_class (cls ))
846+ self .assertFalse (pytree .is_namedtuple (cls ))
847+ self .assertFalse (pytree .is_namedtuple_class (cls ))
848+
849+ inst = cls (range (cls .n_sequence_fields ))
850+ self .assertTrue (pytree .is_structseq (inst ))
851+ self .assertTrue (pytree .is_structseq (type (inst )))
852+ self .assertFalse (pytree .is_structseq_class (inst ))
853+ self .assertTrue (pytree .is_structseq_class (type (inst )))
854+ self .assertFalse (pytree .is_namedtuple (inst ))
855+ self .assertFalse (pytree .is_namedtuple_class (inst ))
856+ else :
857+ self .assertFalse (pytree .is_structseq (cls ))
858+ self .assertFalse (pytree .is_structseq_class (cls ))
859+ self .assertFalse (pytree .is_namedtuple (cls ))
860+ self .assertFalse (pytree .is_namedtuple_class (cls ))
861+
734862
735863class TestPythonPytree (TestCase ):
736864 def test_deprecated_register_pytree_node (self ):
@@ -975,9 +1103,8 @@ def test_pytree_serialize_namedtuple(self):
9751103 serialized_type_name = "test_pytree.test_pytree_serialize_namedtuple.Point1" ,
9761104 )
9771105
978- spec = py_pytree .TreeSpec (
979- namedtuple , Point1 , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
980- )
1106+ spec = py_pytree .tree_structure (Point1 (1 , 2 ))
1107+ self .assertIs (spec .type , namedtuple )
9811108 roundtrip_spec = py_pytree .treespec_loads (py_pytree .treespec_dumps (spec ))
9821109 self .assertEqual (spec , roundtrip_spec )
9831110
@@ -990,18 +1117,28 @@ class Point2(NamedTuple):
9901117 serialized_type_name = "test_pytree.test_pytree_serialize_namedtuple.Point2" ,
9911118 )
9921119
993- spec = py_pytree .TreeSpec (
994- namedtuple , Point2 , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1120+ spec = py_pytree .tree_structure (Point2 (1 , 2 ))
1121+ self .assertIs (spec .type , namedtuple )
1122+ roundtrip_spec = py_pytree .treespec_loads (py_pytree .treespec_dumps (spec ))
1123+ self .assertEqual (spec , roundtrip_spec )
1124+
1125+ class Point3 (Point2 ):
1126+ pass
1127+
1128+ py_pytree ._register_namedtuple (
1129+ Point3 ,
1130+ serialized_type_name = "test_pytree.test_pytree_serialize_namedtuple.Point3" ,
9951131 )
1132+
1133+ spec = py_pytree .tree_structure (Point3 (1 , 2 ))
1134+ self .assertIs (spec .type , namedtuple )
9961135 roundtrip_spec = py_pytree .treespec_loads (py_pytree .treespec_dumps (spec ))
9971136 self .assertEqual (spec , roundtrip_spec )
9981137
9991138 def test_pytree_serialize_namedtuple_bad (self ):
10001139 DummyType = namedtuple ("DummyType" , ["x" , "y" ])
10011140
1002- spec = py_pytree .TreeSpec (
1003- namedtuple , DummyType , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1004- )
1141+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
10051142
10061143 with self .assertRaisesRegex (
10071144 NotImplementedError , "Please register using `_register_namedtuple`"
@@ -1020,9 +1157,7 @@ def __init__(self, x, y):
10201157 lambda xs , _ : DummyType (* xs ),
10211158 )
10221159
1023- spec = py_pytree .TreeSpec (
1024- DummyType , None , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1025- )
1160+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
10261161 with self .assertRaisesRegex (
10271162 NotImplementedError , "No registered serialization name"
10281163 ):
@@ -1042,9 +1177,7 @@ def __init__(self, x, y):
10421177 to_dumpable_context = lambda context : "moo" ,
10431178 from_dumpable_context = lambda dumpable_context : None ,
10441179 )
1045- spec = py_pytree .TreeSpec (
1046- DummyType , None , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1047- )
1180+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
10481181 serialized_spec = py_pytree .treespec_dumps (spec , 1 )
10491182 self .assertIn ("moo" , serialized_spec )
10501183 roundtrip_spec = py_pytree .treespec_loads (serialized_spec )
@@ -1082,9 +1215,7 @@ def __init__(self, x, y):
10821215 from_dumpable_context = lambda dumpable_context : None ,
10831216 )
10841217
1085- spec = py_pytree .TreeSpec (
1086- DummyType , None , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1087- )
1218+ spec = py_pytree .tree_structure (DummyType (1 , 2 ))
10881219
10891220 with self .assertRaisesRegex (
10901221 TypeError , "Object of type type is not JSON serializable"
@@ -1095,9 +1226,7 @@ def test_pytree_serialize_bad_protocol(self):
10951226 import json
10961227
10971228 Point = namedtuple ("Point" , ["x" , "y" ])
1098- spec = py_pytree .TreeSpec (
1099- namedtuple , Point , [py_pytree .LeafSpec (), py_pytree .LeafSpec ()]
1100- )
1229+ spec = py_pytree .tree_structure (Point (1 , 2 ))
11011230 py_pytree ._register_namedtuple (
11021231 Point ,
11031232 serialized_type_name = "test_pytree.test_pytree_serialize_bad_protocol.Point" ,
0 commit comments