@@ -13,21 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tests for tensorflow_datasets.core.writer."""
-
 import json
 import os
+import tempfile
 from typing import Optional
 from unittest import mock
 
 from absl.testing import parameterized
+import apache_beam as beam
 from etils import epath
 import tensorflow as tf
 from tensorflow_datasets import testing
 from tensorflow_datasets.core import dataset_utils
 from tensorflow_datasets.core import example_parser
 from tensorflow_datasets.core import file_adapters
-from tensorflow_datasets.core import lazy_imports_lib
 from tensorflow_datasets.core import naming
 from tensorflow_datasets.core import writer as writer_lib
 from tensorflow_datasets.core.utils import shard_utils
@@ -409,6 +408,10 @@ def test_too_small_split(self):
     self._write(to_write=to_write)
 
 
+def _get_runner() -> beam.runners.PipelineRunner:
+  return beam.runners.DirectRunner()
+
+
 class TfrecordsWriterBeamTest(testing.TestCase):
   NUM_SHARDS = 3
   RECORDS_TO_WRITE = [(i, str(i).encode('utf-8')) for i in range(10)]
@@ -455,7 +458,6 @@ def _write(
     shard_config = shard_config or shard_utils.ShardConfig(
         num_shards=self.NUM_SHARDS
     )
-    beam = lazy_imports_lib.lazy_imports.apache_beam
     writer = writer_lib.BeamWriter(
         serializer=testing.DummySerializer('dummy specs'),
         filename_template=filename_template,
@@ -581,6 +583,50 @@ def test_write_tfrecord_sorted_by_key_with_holes(self):
     self.assertEmpty(all_indices)
 
 
+class NoShuffleBeamWriterTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ('tfrecord', file_adapters.FileFormat.TFRECORD),
+  )
+  def test_write_beam(self, file_format: file_adapters.FileFormat):
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+      tmp_dir = epath.Path(tmp_dir)
+      filename_template = naming.ShardedFileTemplate(
+          dataset_name='foo',
+          split='train',
+          filetype_suffix=file_format.file_suffix,
+          data_dir=tmp_dir,
+      )
+      writer = writer_lib.NoShuffleBeamWriter(
+          serializer=testing.DummySerializer('dummy specs'),
+          filename_template=filename_template,
+          file_format=file_format,
+      )
+      to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
+      # Disable type checking because `beam.Create` cannot infer the element
+      # type of the resulting PCollection.
+      options = beam.options.pipeline_options.PipelineOptions(
+          pipeline_type_check=False
+      )
+      with beam.Pipeline(options=options, runner=_get_runner()) as pipeline:
+
+        @beam.ptransform_fn
+        def _build_pcollection(pipeline):
+          pcollection = pipeline | 'Start' >> beam.Create(to_write)
+          return writer.write_from_pcollection(pcollection)
+
+        _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
+      shard_lengths, total_size = writer.finalize()
+      self.assertNotEmpty(shard_lengths)
+      self.assertEqual(sum(shard_lengths), 10)
+      self.assertGreater(total_size, 10)
+      files = list(tmp_dir.iterdir())
+      self.assertGreaterEqual(len(files), 1)
+      for f in files:
+        self.assertIn(file_format.file_suffix, f.name)
+
+
 class CustomExampleWriter(writer_lib.ExampleWriter):
 
   def __init__(self):