Skip to content

Commit 9d6090d

Browse files
marcenacp authored and The TensorFlow Datasets Authors committed
Simplify ReadFromCroissant by removing the pipeline argument and making it a PCollection.
PiperOrigin-RevId: 702731977
1 parent 0419f1a commit 9d6090d

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -288,6 +288,8 @@ def _split_generators(
288288
dl_manager: download.DownloadManager,
289289
pipeline: beam.Pipeline,
290290
) -> dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
291+
del dl_manager # unused
292+
del pipeline # unused
291293
# If a split recordset is joined for the required record set, we generate
292294
# splits accordingly. Otherwise, it generates a single `default` split with
293295
# all the records.
@@ -302,7 +304,6 @@ def _split_generators(
302304
split_key = split_reference.reference_field.references.field
303305
return {
304306
split[split_key]: self._generate_examples(
305-
pipeline=pipeline,
306307
filters={
307308
**self._filters,
308309
split_reference.reference_field.id: split[split_key],
@@ -311,21 +312,15 @@ def _split_generators(
311312
for split in split_reference.split_record_set.data
312313
}
313314
else:
314-
return {
315-
'default': self._generate_examples(
316-
pipeline=pipeline, filters=self._filters
317-
)
318-
}
315+
return {'default': self._generate_examples(filters=self._filters)}
319316

320317
def _generate_examples(
321318
self,
322-
pipeline: beam.Pipeline,
323319
filters: dict[str, Any],
324320
) -> beam.PTransform:
325321
"""Generates the examples for the given record set.
326322
327323
Args:
328-
pipeline: The Beam pipeline.
329324
filters: A dict of filters to apply to the records. The keys should be
330325
field names and the values should be the values to filter by. If a
331326
record matches all the filters, it will be included in the dataset.
@@ -354,10 +349,12 @@ def convert_to_tfds_format(
354349
conversion_utils.to_tfds_value(record, features),
355350
)
356351

357-
return records.beam_reader(
358-
pipeline=pipeline
359-
) | f'Convert to TFDS format for filters: {json.dumps(filters)}' >> beam.MapTuple(
360-
convert_to_tfds_format,
361-
features=self.info.features,
362-
record_set_id=record_set.id,
352+
return (
353+
records.beam_reader()
354+
| f'Convert to TFDS format for filters: {json.dumps(filters)}'
355+
>> beam.MapTuple(
356+
convert_to_tfds_format,
357+
features=self.info.features,
358+
record_set_id=record_set.id,
359+
)
363360
)

0 commit comments

Comments
 (0)