Skip to content

Commit 6b93631

Browse files
tomvdw and The TensorFlow Datasets Authors
authored and committed
Do several small refactorings
PiperOrigin-RevId: 704717347
1 parent cfeb104 commit 6b93631

File tree

6 files changed

+37
-31
lines changed

6 files changed

+37
-31
lines changed

tensorflow_datasets/core/data_sources/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import typing
2121
from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar
2222

23+
from absl import logging
2324
from tensorflow_datasets.core import dataset_info as dataset_info_lib
2425
from tensorflow_datasets.core import decode
2526
from tensorflow_datasets.core import splits as splits_lib

tensorflow_datasets/core/dataset_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1543,7 +1543,6 @@ def _as_dataset(
15431543
file_format = file_adapters.FileFormat.from_value(file_format)
15441544

15451545
reader = reader_lib.Reader(
1546-
self.data_dir,
15471546
example_specs=example_specs,
15481547
file_format=file_format,
15491548
)

tensorflow_datasets/core/reader.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717

1818
from __future__ import annotations
1919

20+
from collections.abc import Sequence
2021
import functools
2122
import os
2223
import re
23-
from typing import Any, Callable, List, NamedTuple, Optional, Sequence
24+
from typing import Any, Callable, NamedTuple
2425

2526
from absl import logging
2627
import numpy as np
@@ -63,7 +64,7 @@ def _get_dataset_from_filename(
6364
do_take: bool,
6465
file_format: file_adapters.FileFormat,
6566
add_tfds_id: bool,
66-
override_buffer_size: Optional[int] = None,
67+
override_buffer_size: int | None = None,
6768
) -> tf.data.Dataset:
6869
"""Returns a tf.data.Dataset instance from given instructions."""
6970
ds = file_adapters.ADAPTER_FOR_FORMAT[file_format].make_tf_data(
@@ -361,39 +362,38 @@ def _verify_read_config_for_ordered_dataset(
361362
logging.warning(error_message)
362363

363364

364-
class Reader(object):
365+
class Reader:
365366
"""Build a tf.data.Dataset object out of Instruction instance(s).
366367
367368
This class should not typically be exposed to the TFDS user.
368369
"""
369370

370371
def __init__(
371372
self,
372-
path, # TODO(b/216427814) remove this as it isn't used anymore
373373
example_specs,
374-
file_format=file_adapters.DEFAULT_FILE_FORMAT,
374+
file_format: (
375+
str | file_adapters.FileFormat
376+
) = file_adapters.DEFAULT_FILE_FORMAT,
375377
):
376378
"""Initializes Reader.
377379
378380
Args:
379-
path (str): path where tfrecords are stored.
380381
example_specs: spec to build ExampleParser.
381382
file_format: file_adapters.FileFormat, format of the record files in which
382383
the dataset will be read/written from.
383384
"""
384-
self._path = path
385385
self._parser = example_parser.ExampleParser(example_specs)
386-
self._file_format = file_format
386+
self._file_format = file_adapters.FileFormat.from_value(file_format)
387387

388388
def read(
389389
self,
390390
*,
391391
instructions: Tree[splits_lib.SplitArg],
392-
split_infos: List[splits_lib.SplitInfo],
392+
split_infos: Sequence[splits_lib.SplitInfo],
393393
read_config: read_config_lib.ReadConfig,
394394
shuffle_files: bool,
395395
disable_shuffling: bool = False,
396-
decode_fn: Optional[DecodeFn] = None,
396+
decode_fn: DecodeFn | None = None,
397397
) -> Tree[tf.data.Dataset]:
398398
"""Returns tf.data.Dataset instance(s).
399399
@@ -417,8 +417,11 @@ def read(
417417

418418
splits_dict = splits_lib.SplitDict(split_infos=split_infos)
419419

420-
def _read_instruction_to_ds(instruction):
421-
file_instructions = splits_dict[instruction].file_instructions
420+
def _read_instruction_to_ds(
421+
instruction: splits_lib.SplitArg,
422+
) -> tf.data.Dataset:
423+
split_info = splits_dict[instruction]
424+
file_instructions = split_info.file_instructions
422425
return self.read_files(
423426
file_instructions,
424427
read_config=read_config,
@@ -436,7 +439,7 @@ def read_files(
436439
read_config: read_config_lib.ReadConfig,
437440
shuffle_files: bool,
438441
disable_shuffling: bool = False,
439-
decode_fn: Optional[DecodeFn] = None,
442+
decode_fn: DecodeFn | None = None,
440443
) -> tf.data.Dataset:
441444
"""Returns single tf.data.Dataset instance for the set of file instructions.
442445

tensorflow_datasets/core/reader_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def setUp(self):
7676
with mock.patch.object(
7777
example_parser, 'ExampleParser', testing.DummyParser
7878
):
79-
self.reader = reader_lib.Reader(self.tmp_dir, 'some_spec')
79+
self.reader = reader_lib.Reader(self.tmp_dir, 'tfrecord')
8080
self.reader.read = functools.partial(
8181
self.reader.read,
8282
read_config=read_config_lib.ReadConfig(),

tensorflow_datasets/core/splits.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import abc
21-
from collections.abc import Iterable
21+
from collections.abc import Iterable, Sequence
2222
import dataclasses
2323
import functools
2424
import itertools
@@ -123,7 +123,7 @@ def __post_init__(self):
123123
def get_available_shards(
124124
self,
125125
data_dir: epath.Path | None = None,
126-
file_format: file_adapters.FileFormat | None = None,
126+
file_format: str | file_adapters.FileFormat | None = None,
127127
strict_matching: bool = True,
128128
) -> list[epath.Path]:
129129
"""Returns the list of shards that are present in the data dir.
@@ -140,6 +140,7 @@ def get_available_shards(
140140
"""
141141
if filename_template := self.filename_template:
142142
if file_format:
143+
file_format = file_adapters.FileFormat.from_value(file_format)
143144
filename_template = filename_template.replace(
144145
filetype_suffix=file_format.file_suffix
145146
)
@@ -250,7 +251,9 @@ def replace(self, **kwargs: Any) -> SplitInfo:
250251
"""Returns a copy of the `SplitInfo` with updated attributes."""
251252
return dataclasses.replace(self, **kwargs)
252253

253-
def file_spec(self, file_format: file_adapters.FileFormat) -> str:
254+
def file_spec(
255+
self, file_format: str | file_adapters.FileFormat
256+
) -> str | None:
254257
"""Returns the file spec of the split for the given file format.
255258
256259
A file spec is the full path with sharded notation, e.g.,
@@ -259,6 +262,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
259262
Args:
260263
file_format: the file format for which to create the file spec for.
261264
"""
265+
file_format = file_adapters.FileFormat.from_value(file_format)
262266
if filename_template := self.filename_template:
263267
if filename_template.filetype_suffix != file_format.file_suffix:
264268
raise ValueError(
@@ -268,9 +272,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
268272
return filename_template.sharded_filepaths_pattern(
269273
num_shards=self.num_shards
270274
)
271-
raise ValueError(
272-
f'Could not get filename template for split from split info: {self}.'
273-
)
275+
return None
274276

275277

276278
@dataclasses.dataclass(eq=False, frozen=True)
@@ -425,7 +427,7 @@ def __repr__(self) -> str:
425427
if typing.TYPE_CHECKING:
426428
# For type checking, `tfds.Split` is an alias for `str` with additional
427429
# `.TRAIN`, `.TEST`,... attributes. All strings are valid split type.
428-
Split = Union[Split, str]
430+
Split = Split | str
429431

430432

431433
class SplitDict(utils.NonMutableDict[str, SplitInfo]):
@@ -438,7 +440,7 @@ def __init__(
438440
# TODO(b/216470058): remove this parameter
439441
dataset_name: str | None = None, # deprecated, please don't use
440442
):
441-
super(SplitDict, self).__init__(
443+
super().__init__(
442444
{split_info.name: split_info for split_info in split_infos},
443445
error_msg='Split {key} already present',
444446
)
@@ -457,7 +459,7 @@ def __getitem__(self, key) -> SplitInfo | SubSplitInfo:
457459
)
458460
# 1st case: The key exists: `info.splits['train']`
459461
elif str(key) in self.keys():
460-
return super(SplitDict, self).__getitem__(str(key))
462+
return super().__getitem__(str(key))
461463
# 2nd case: Uses instructions: `info.splits['train[50%]']`
462464
else:
463465
instructions = _make_file_instructions(
@@ -543,7 +545,7 @@ def _file_instructions_for_split(
543545

544546

545547
def _make_file_instructions(
546-
split_infos: list[SplitInfo],
548+
split_infos: Sequence[SplitInfo],
547549
instruction: SplitArg,
548550
) -> list[shard_utils.FileInstruction]:
549551
"""Returns file instructions by applying the given instruction on the given splits.
@@ -587,7 +589,7 @@ class AbstractSplit(abc.ABC):
587589
"""
588590

589591
@classmethod
590-
def from_spec(cls, spec: SplitArg) -> 'AbstractSplit':
592+
def from_spec(cls, spec: SplitArg) -> AbstractSplit:
591593
"""Creates a ReadInstruction instance out of a string spec.
592594
593595
Args:
@@ -632,7 +634,7 @@ def to_absolute(self, split_infos) -> list[_AbsoluteInstruction]:
632634
"""
633635
raise NotImplementedError
634636

635-
def __add__(self, other: Union[str, 'AbstractSplit']) -> 'AbstractSplit':
637+
def __add__(self, other: str | AbstractSplit) -> AbstractSplit:
636638
"""Sum of 2 splits."""
637639
if not isinstance(other, (str, AbstractSplit)):
638640
raise TypeError(f'Adding split {self!r} with non-split value: {other!r}')

tensorflow_datasets/core/splits_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,10 +666,11 @@ def test_file_spec_missing_template(self):
666666
num_bytes=42,
667667
filename_template=None,
668668
)
669-
with self.assertRaises(ValueError):
670-
split_info.file_spec(
671-
file_format=tfds.core.file_adapters.FileFormat.TFRECORD
672-
)
669+
self.assertIsNone(
670+
split_info.file_spec(
671+
file_format=tfds.core.file_adapters.FileFormat.TFRECORD
672+
)
673+
)
673674

674675
def test_get_available_shards(self):
675676
tmp_dir = epath.Path(self.tmp_dir)

0 commit comments

Comments (0)