
Commit f798fe9

pierrot0 authored and The TensorFlow Datasets Authors committed
Wrap metadata read file operations into a retry.
PiperOrigin-RevId: 767092333
1 parent 374c3c5 commit f798fe9

6 files changed: +98 −14 lines changed


tensorflow_datasets/core/constants.py

Lines changed: 25 additions & 0 deletions
@@ -49,3 +49,28 @@
 # Filepath for mapping between TFDS datasets and PapersWithCode entries.
 PWC_FILENAME = 'tfds_to_pwc_links.json'
 PWC_LINKS_PATH = f'scripts/documentation/{PWC_FILENAME}'
+
+# Retry parameters. Delays are in seconds.
+TFDS_RETRY_TRIES = int(os.environ.get('TFDS_RETRY_TRIES', 3))
+TFDS_RETRY_INITIAL_DELAY = int(os.environ.get('TFDS_RETRY_INITIAL_DELAY', 1))
+# How much to multiply the delay by for each subsequent try.
+TFDS_RETRY_DELAY_MULTIPLIER = int(
+    os.environ.get('TFDS_RETRY_DELAY_MULTIPLIER', 2)
+)
+# Random noise to add to the delay (random pick between 0 and noise).
+TFDS_RETRY_NOISE = float(os.environ.get('TFDS_RETRY_NOISE', 0.5))
+# If the error message contains any of these substrings, retry.
+TFDS_RETRY_MSG_SUBSTRINGS = os.environ.get(
+    'TFDS_RETRY_MSG_SUBSTRINGS',
+    (
+        'deadline_exceeded,'
+        '408 Request Timeout,'
+        '429 Too Many Requests,'
+        '500 Internal Server Error,'
+        '502 Bad Gateway,'
+        '503 Service Unavailable,'
+        '504 Gateway Timeout,'
+        '509 Bandwidth Limit Exceeded,'
+        '599 Gateway Error'
+    ),
+).split(',')
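All of these knobs are read from os.environ when constants.py is imported, so they can be tuned per run without code changes. A minimal sketch of an override (the values are illustrative, not defaults):

import os

# Illustrative overrides; they must be set before tensorflow_datasets (and
# therefore constants.py) is first imported, since the values are read at
# module import time.
os.environ['TFDS_RETRY_TRIES'] = '5'
os.environ['TFDS_RETRY_INITIAL_DELAY'] = '2'
os.environ['TFDS_RETRY_NOISE'] = '1.0'

import tensorflow_datasets as tfds  # picks up the overrides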

tensorflow_datasets/core/dataset_builder.py

Lines changed: 13 additions & 11 deletions
@@ -64,6 +64,7 @@
 from tensorflow_datasets.core.utils import file_utils
 from tensorflow_datasets.core.utils import gcs_utils
 from tensorflow_datasets.core.utils import read_config as read_config_lib
+from tensorflow_datasets.core.utils import retry
 from tensorflow_datasets.core.utils import type_utils
 # pylint: enable=g-import-not-at-top
 
@@ -290,7 +291,8 @@ def __init__(
     # Compute the base directory (for download) and dataset/version directory.
     self._data_dir_root, self._data_dir = self._build_data_dir(data_dir)
     # If the dataset info is available, use it.
-    if dataset_info.dataset_info_path(self.data_path).exists():
+    dataset_info_path = dataset_info.dataset_info_path(self.data_path)
+    if retry.retry(dataset_info_path.exists):
       self.info.read_from_directory(self._data_dir)
     else:  # Use the code version (do not restore data)
       self.info.initialize_from_bucket()
@@ -466,8 +468,8 @@ def _checksums_path(cls) -> epath.Path | None:
         # zipfile.Path does not have `.parts`. Additionally, `os.fspath`
         # will extract the file, so use `str`.
         "tensorflow_datasets" in str(new_path)
-        and legacy_path.exists()
-        and not new_path.exists()
+        and retry.retry(legacy_path.exists)
+        and not retry.retry(new_path.exists)
     ):
       return legacy_path
     else:
@@ -484,7 +486,7 @@ def url_infos(cls) -> dict[str, download.checksums.UrlInfo] | None:
     # Search for the url_info file.
     checksums_path = cls._checksums_path
     # If url_info file is found, load the urls
-    if checksums_path and checksums_path.exists():
+    if checksums_path and retry.retry(checksums_path.exists):
       return download.checksums.load_url_infos(checksums_path)
     else:
       return None
@@ -624,7 +626,7 @@ def download_and_prepare(
 
     download_config = download_config or download.DownloadConfig()
     data_path = self.data_path
-    data_exists = data_path.exists()
+    data_exists = retry.retry(data_path.exists)
 
     # Saving nondeterministic_order in the DatasetInfo for documentation.
     if download_config.nondeterministic_order:
@@ -640,7 +642,7 @@
             "Deleting pre-existing dataset %s (%s)", self.name, self.data_dir
         )
         data_path.rmtree()  # Delete pre-existing data.
-        data_exists = data_path.exists()
+        data_exists = retry.retry(data_path.exists)
       else:
         logging.info("Reusing dataset %s (%s)", self.name, self.data_dir)
         return
@@ -805,7 +807,7 @@ def _post_download_and_prepare_hook(self) -> None:
   def _update_dataset_info(self) -> None:
     """Updates the `dataset_info.json` file in the dataset dir."""
     info_file = self.data_path / constants.DATASET_INFO_FILENAME
-    if not info_file.exists():
+    if not retry.retry(info_file.exists):
       raise AssertionError(f"To update {info_file}, it must already exist.")
     new_info = self.info
     new_info.read_from_directory(self.data_path)
@@ -1020,7 +1022,7 @@ def as_dataset(
     self.assert_is_not_blocked()
 
     # pylint: enable=line-too-long
-    if not self.data_path.exists():
+    if not retry.retry(self.data_path.exists):
       raise AssertionError(
           "Dataset %s: could not find data in %s. Please make sure to call "
           "dataset_builder.download_and_prepare(), or pass download=True to "
@@ -1817,7 +1819,7 @@ def read_text_file(
     """Returns the text in the given file and records the lineage."""
     filename = epath.Path(filename)
     self.info.add_file_data_source_access(filename)
-    return filename.read_text(encoding=encoding)
+    return retry.retry(filename.read_text, encoding=encoding)
 
   def read_tfrecord_as_dataset(
       self,
@@ -2057,9 +2059,9 @@ def _save_default_config_name(
 def load_default_config_name(builder_dir: epath.Path) -> str | None:
   """Load `builder_cls` metadata (common to all builder configs)."""
   config_path = builder_dir / ".config" / constants.METADATA_FILENAME
-  if not config_path.exists():
+  if not retry.retry(config_path.exists):
     return None
-  data = json.loads(config_path.read_text())
+  data = json.loads(retry.retry(config_path.read_text))
   return data.get("default_config_name")
 
 
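Every call site above follows the same pattern: the bound method itself (not its result) is passed to retry.retry along with any arguments, so the helper can re-invoke it after a transient failure. A minimal sketch of that pattern (the path value is illustrative, not from the commit):

from etils import epath

from tensorflow_datasets.core.utils import retry

path = epath.Path('/tmp/my_dataset/1.0.0/dataset_info.json')  # illustrative
exists = retry.retry(path.exists)  # retried call with no arguments
if exists:
  # Keyword arguments are forwarded, mirroring read_text_file above.
  text = retry.retry(path.read_text, encoding='utf-8')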

tensorflow_datasets/core/dataset_info.py

Lines changed: 2 additions & 1 deletion
@@ -61,6 +61,7 @@
 # pylint: disable=g-import-not-at-top
 from tensorflow_datasets.core.utils import file_utils
 from tensorflow_datasets.core.utils import gcs_utils
+from tensorflow_datasets.core.utils import retry
 
 from google.protobuf import json_format
 # pylint: enable=g-import-not-at-top
@@ -1123,7 +1124,7 @@ def read_from_json(path: epath.PathLike) -> dataset_info_pb2.DatasetInfo:
     DatasetInfoFileError: If the dataset info file cannot be read.
   """
   try:
-    json_str = epath.Path(path).read_text()
+    json_str = retry.retry(epath.Path(path).read_text)
   except OSError as e:
     raise DatasetInfoFileError(
         f"Could not read dataset info from {path}"

tensorflow_datasets/core/features/feature.py

Lines changed: 2 additions & 1 deletion
@@ -37,6 +37,7 @@
 from tensorflow_datasets.core.utils import dtype_utils
 from tensorflow_datasets.core.utils import np_utils
 from tensorflow_datasets.core.utils import py_utils
+from tensorflow_datasets.core.utils import retry
 from tensorflow_datasets.core.utils import tf_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
@@ -658,7 +659,7 @@ def from_config(cls, root_dir: str) -> FeatureConnector:
     Returns:
       The reconstructed feature instance.
     """
-    content = json.loads(make_config_path(root_dir).read_text())
+    content = json.loads(retry.retry(make_config_path(root_dir).read_text))
     feature = FeatureConnector.from_json(content)
     feature.load_metadata(root_dir, feature_name=None)
     return feature

tensorflow_datasets/core/splits.py

Lines changed: 2 additions & 1 deletion
@@ -36,6 +36,7 @@
 from tensorflow_datasets.core import proto as proto_lib
 from tensorflow_datasets.core import units
 from tensorflow_datasets.core import utils
+from tensorflow_datasets.core.utils import retry
 from tensorflow_datasets.core.utils import shard_utils
 
 from tensorflow_metadata.proto.v0 import statistics_pb2
@@ -149,7 +150,7 @@ def get_available_shards(
         pattern = filename_template.glob_pattern(num_shards=self.num_shards)
       else:
         pattern = filename_template.sharded_filepaths_pattern(num_shards=None)
-      return list(data_dir.glob(pattern))
+      return list(retry.retry(data_dir.glob, pattern))
     else:
       raise ValueError(f'Filename template for split {self.name} is empty.')
 

tensorflow_datasets/core/utils/retry.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# coding=utf-8
+# Copyright 2025 The TensorFlow Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""To add retry logic to operations susceptible to transient failures."""
+
+import random
+import time
+from typing import Callable, ParamSpec, TypeVar
+
+from absl import logging
+from tensorflow_datasets.core import constants
+
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def retry(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
+  """Calls the given function, retrying on transient errors."""
+  # We purposely don't use flags, as this code might be run before flags are
+  # parsed.
+  tries = constants.TFDS_RETRY_TRIES
+  delay = constants.TFDS_RETRY_INITIAL_DELAY
+  multiplier = constants.TFDS_RETRY_DELAY_MULTIPLIER
+  noise = constants.TFDS_RETRY_NOISE
+  msg_substrings = constants.TFDS_RETRY_MSG_SUBSTRINGS
+  for trial in range(1, tries + 1):
+    try:
+      return func(*args, **kwargs)
+    except BaseException as err:  # pylint: disable=broad-except
+      if trial >= tries:
+        raise err
+      msg = str(err)
+      for msg_substring in msg_substrings:
+        if msg_substring in msg:
+          break
+      else:
+        raise err
+      delay = delay + random.uniform(0, noise)
+      logging.warning("%s, retrying in %s seconds...", msg, delay)
+      time.sleep(delay)
+      delay *= multiplier
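Note the for/else on the substring scan: the else branch runs only when no retryable substring matched the error message, so non-transient errors propagate immediately. A minimal usage sketch of the helper (flaky_read and its failure mode are illustrative, not part of the commit):

from tensorflow_datasets.core.utils import retry

state = {'calls': 0}

def flaky_read() -> str:
  state['calls'] += 1
  if state['calls'] == 1:
    raise OSError('429 Too Many Requests')  # retryable under the defaults
  return 'metadata'

# The first failure matches a retryable substring, so retry() sleeps roughly
# 1-1.5 s (initial delay plus uniform noise) and calls flaky_read again; each
# later wait roughly doubles under the default multiplier of 2.
assert retry.retry(flaky_read) == 'metadata'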
