Skip to content

Commit 4baec15

Browse files
author
The TensorFlow Datasets Authors
committed
Create a new dataset object for each _generate_example call.
PiperOrigin-RevId: 716664290
1 parent 855b1cd commit 4baec15

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,17 @@ def __init__(
217217
"""
218218
if mapping is None:
219219
mapping = {}
220-
self.dataset = mlc.Dataset(jsonld, mapping=mapping)
221-
self.name = croissant_utils.get_tfds_dataset_name(self.dataset)
222-
self.metadata = self.dataset.metadata
220+
self.jsonld = jsonld
221+
self.mapping = mapping
222+
dataset = mlc.Dataset(jsonld, mapping=mapping)
223+
self.name = croissant_utils.get_tfds_dataset_name(dataset)
224+
self.metadata = dataset.metadata
223225

224226
# In TFDS, version is a mandatory attribute, while in Croissant it is only a
225227
# recommended attribute. If the version is unspecified in Croissant, we set
226228
# it to `1.0.0` in TFDS.
227229
self.VERSION = version_lib.Version( # pylint: disable=invalid-name
228-
overwrite_version or self.dataset.metadata.version or '1.0.0'
230+
overwrite_version or self.metadata.version or '1.0.0'
229231
)
230232
self.RELEASE_NOTES = {} # pylint: disable=invalid-name
231233

@@ -260,11 +262,11 @@ def builder_config(self) -> dataset_builder.BuilderConfig:
260262
def _info(self) -> dataset_info.DatasetInfo:
261263
return dataset_info.DatasetInfo(
262264
builder=self,
263-
description=self.dataset.metadata.description,
265+
description=self.metadata.description,
264266
features=self.get_features(),
265-
homepage=self.dataset.metadata.url,
266-
citation=self.dataset.metadata.cite_as,
267-
license=_get_license(self.dataset.metadata),
267+
homepage=self.metadata.url,
268+
citation=self.metadata.cite_as,
269+
license=_get_license(self.metadata),
268270
disable_shuffling=self._disable_shuffling,
269271
)
270272

@@ -331,7 +333,8 @@ def _generate_examples(
331333
record_set = croissant_utils.get_record_set(
332334
self.builder_config.name, metadata=self.metadata
333335
)
334-
records = self.dataset.records(record_set.id, filters=filters)
336+
dataset = mlc.Dataset(self.jsonld, mapping=self.mapping)
337+
records = dataset.records(record_set.id, filters=filters)
335338

336339
def convert_to_tfds_format(
337340
global_index: int,

tensorflow_datasets/core/dataset_builders/croissant_builder_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,10 @@ def test_croissant_builder(crs_builder):
242242
def test_download_and_prepare(crs_builder, expected_entries, split_name):
243243
crs_builder.download_and_prepare()
244244
data_source = crs_builder.as_data_source(split=split_name)
245-
assert len(data_source) == 1
245+
expected_entries = [
246+
entry for entry in expected_entries if entry["split"] == split_name
247+
]
248+
assert len(data_source) == len(expected_entries) == 1
246249
for entry, expected_entry in zip(data_source, expected_entries):
247250
assert entry["index"] == expected_entry["index"]
248251
assert entry["text"].decode() == expected_entry["text"]

0 commit comments

Comments
 (0)