2828import types
2929from sys import intern
3030from types import CodeType , EllipsisType
31- from typing import Optional
31+ from typing import Any , Optional , Set , Union
3232
3333from xdis .codetype import Code2 , Code3 , Code15
3434from xdis .unmarshal import (
@@ -88,11 +88,16 @@ class _Marshaller:
8888 dispatch = {}
8989
9090 def __init__ (
91- self , writefunc , python_version : tuple , is_pypy : Optional [bool ] = None
91+ self ,
92+ writefunc ,
93+ python_version : tuple ,
94+ is_pypy : Optional [bool ] = None ,
95+ collection_order = {},
9296 ) -> None :
9397 self ._write = writefunc
9498 self .python_version = python_version
9599 self .is_pypy = is_pypy
100+ self .collection_order = collection_order
96101
97102 def dump (self , x ) -> None :
98103 if (
@@ -417,6 +422,7 @@ def dump_code3(self, x) -> None:
417422
418423 dispatch [Code3 ] = dump_code3
419424
425+ # FIXME: this is wrong.
420426 try :
421427 if PYTHON3 :
422428 dispatch [types .CodeType ] = dump_code3
@@ -425,22 +431,32 @@ def dump_code3(self, x) -> None:
425431 except NameError :
426432 pass
427433
428- def dump_set (self , x ) -> None :
429- self ._write (TYPE_SET )
430- self .w_long (len (x ))
431- for each in x :
434+ def dump_collection (self , type_code : str , bag : Union [frozenset , set , dict ]) -> None :
435+ """
436+ Save marshalled version of frozenset fs.
437+ Use self.collection_order, to ensure that the order
438+ or set elements that may have appeared from unmarshalling the appears
439+ the same way. This helps roundtrip checking, among possibly other things.
440+ """
441+ self ._write (type_code )
442+ self .w_long (len (bag ))
443+ collection = self .collection_order .get (bag , bag )
444+ for each in collection :
432445 self .dump (each )
433446
434- try :
435- dispatch [set ] = dump_set
436- except NameError :
437- pass
447+ def dump_set (self , s : Set [Any ]) -> None :
448+ """
449+ Save marshalled version of set s.
450+ """
451+ self .dump_collection (TYPE_SET , s )
438452
439- def dump_frozenset (self , x ) -> None :
440- self ._write (TYPE_FROZENSET )
441- self .w_long (len (x ))
442- for each in x :
443- self .dump (each )
453+ dispatch [set ] = dump_set
454+
455+ def dump_frozenset (self , fs : frozenset ) -> None :
456+ """
457+ Save marshalled version of frozenset fs.
458+ """
459+ self .dump_collection (TYPE_FROZENSET , fs )
444460
445461 try :
446462 dispatch [frozenset ] = dump_frozenset
@@ -1103,7 +1119,13 @@ def dumps(
11031119 is_pypy : Optional [bool ] = None ,
11041120) -> bytes | str :
11051121 buffer = []
1106- m = _Marshaller (buffer .append , python_version = python_version , is_pypy = is_pypy )
1122+ collection_order = x .collection_order if hasattr (x , "collection_order" ) else {}
1123+ m = _Marshaller (
1124+ buffer .append ,
1125+ python_version = python_version ,
1126+ is_pypy = is_pypy ,
1127+ collection_order = collection_order ,
1128+ )
11071129 m .dump (x )
11081130 if python_version :
11091131 is_python3 = python_version >= (3 , 0 )
0 commit comments