class PickableDataSourceMock(mock.MagicMock):
  """Makes MagicMock picklable in order to work with multiprocessing in Grain.

  The pickled state is a plain dict holding the number of examples, the
  example generator and the optional `serialize_example` callable. On
  unpickling, `__len__`, `__getitem__` and `__getitems__` are rebuilt from that
  dict.
  """

  def __getstate__(self):
    """Returns the dict of values needed to rebuild this mock."""
    return {
        'num_examples': len(self),
        'generator': self._generator,
        'serialize_example': self._serialize_example,
    }

  def __setstate__(self, state):
    """Rebuilds the mocked magic methods from `state`.

    Args:
      state: dict with the keys produced by `__getstate__`.
    """
    num_examples = state['num_examples']
    generator = state['generator']
    serialize_example = state['serialize_example']
    # Restore the attributes read by `__getstate__`. Without them a restored
    # mock would return auto-created child mocks there and could not be
    # pickled a second time.
    self._generator = generator
    self._serialize_example = serialize_example
    self.__len__.return_value = num_examples
    self.__getitem__ = functools.partial(
        _getitem, generator=generator, serialize_example=serialize_example
    )
    self.__getitems__ = functools.partial(
        _getitems, generator=generator, serialize_example=serialize_example
    )

  def __reduce__(self):
    """Pickles as `(class, no constructor args, state dict)`."""
    return (PickableDataSourceMock, (), self.__getstate__())
def _getitem(
    self,
    record_key: int,
    generator: 'RandomFakeGenerator',
    serialize_example=None,
) -> Any:
  """Function to overwrite __getitem__ in data sources.

  Args:
    self: the mocked data source the function is bound to; unused.
    record_key: key used to index `generator`.
    generator: object indexed with `record_key` to obtain the example.
    serialize_example: optional callable applied to the example before it is
      returned. When it is `None` (or otherwise falsy) the example is returned
      unchanged.

  Returns:
    `serialize_example(example)` when `serialize_example` is given, otherwise
    `generator[record_key]` itself.
  """
  del self  # Accepted for the bound-method signature only.
  example = generator[record_key]
  if serialize_example:
    # Return the serialized form of the example.
    return serialize_example(example)
  return example
110
123
111
124
112
125
def _getitems(
    self,
    record_keys: Sequence[int],
    generator: 'RandomFakeGenerator',
    serialize_example=None,
) -> Sequence[Any]:
  """Function to overwrite __getitems__ in data sources.

  Args:
    self: the mocked data source the function is bound to; forwarded to
      `_getitem`.
    record_keys: keys of the examples to fetch, in order.
    generator: object indexed with each key to obtain the examples.
    serialize_example: optional callable applied to every example.

  Returns:
    A NumPy object array of the serialized examples when `serialize_example`
    is given, otherwise a plain list of the examples.
  """
  items = [
      _getitem(self, record_key, generator, serialize_example=serialize_example)
      for record_key in record_keys
  ]
  if serialize_example:
    # dtype=object stores each record untouched. Letting NumPy infer a
    # fixed-width bytes dtype would drop trailing null bytes when the elements
    # are read back, corrupting records that end in b'\x00'.
    return np.array(items, dtype=object)
  return items
126
139
127
140
128
- def _deserialize_example_np (serialized_example , * , decoders = None ):
129
- """Function to overwrite dataset_info.features.deserialize_example_np.
130
-
131
- Warning: this has to be defined in the outer scope in order for the function
132
- to be pickable.
133
-
134
- Args:
135
- serialized_example: the example to deserialize.
136
- decoders: optional decoders.
137
-
138
- Returns:
139
- The serialized example, because deserialization is taken care by
140
- RandomFakeGenerator.
141
- """
142
- del decoders
143
- return serialized_example
144
-
145
-
146
141
class MockPolicy (enum .Enum ):
147
142
"""Strategy to use with `tfds.testing.mock_data` to mock the dataset.
148
143
@@ -385,21 +380,27 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
385
380
# Force ARRAY_RECORD as the default file_format.
386
381
return_value = file_adapters .FileFormat .ARRAY_RECORD ,
387
382
):
388
- self . info . features . deserialize_example_np = _deserialize_example_np
383
+ # Make mock_data_source picklable with a given len:
389
384
mock_data_source .return_value .__len__ .return_value = num_examples
385
+ # Make mock_data_source picklable with a given generator:
390
386
mock_data_source .return_value ._generator = ( # pylint:disable=protected-access
391
387
generator
392
388
)
389
+ # Make mock_data_source picklable with a given serialize_example:
390
+ mock_data_source .return_value ._serialize_example = ( # pylint:disable=protected-access
391
+ self .info .features .serialize_example
392
+ )
393
+ serialize_example = self .info .features .serialize_example
393
394
mock_data_source .return_value .__getitem__ = functools .partial (
394
- _getitem , generator = generator
395
+ _getitem , generator = generator , serialize_example = serialize_example
395
396
)
396
397
mock_data_source .return_value .__getitems__ = functools .partial (
397
- _getitems , generator = generator
398
+ _getitems , generator = generator , serialize_example = serialize_example
398
399
)
399
400
400
401
def build_single_data_source(split):
  """Builds the `ArrayRecordDataSource` for one split.

  Uses `self` and `decoders` from the enclosing scope.

  Args:
    split: the split to build the data source for.

  Returns:
    The `array_record.ArrayRecordDataSource` created for `split`.
  """
  return array_record.ArrayRecordDataSource(
      dataset_builder=self, split=split, decoders=decoders
  )
405
406
0 commit comments