Skip to content

Commit 17e28ea

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Add a wrapper around itertools.tee to make it a thread-safe
When writing to multiple output formats, users ran into problems. A runtime exception was raised: `generator raised StopIteration`. This is caused by itertools.tee not being thread safe: https://docs.python.org/3/library/itertools.html#itertools.tee PiperOrigin-RevId: 664708727
1 parent cd0fa0c commit 17e28ea

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

tensorflow_datasets/core/writer.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
from __future__ import annotations
1919

20-
from collections.abc import Iterable, Sequence
20+
from collections.abc import Iterable, Iterator, Sequence
2121
import dataclasses
2222
import functools
2323
import itertools
2424
import json
2525
import os
26+
import threading
2627
from typing import Any
2728

2829
from etils import epy
@@ -190,6 +191,38 @@ def write(
190191
return adapter.write_examples(path, examples)
191192

192193

194+
class ThreadSafeIterator(Iterator):
195+
"""A wrapper around a tee object to make it thread-safe.
196+
197+
See https://stackoverflow.com/q/6703594 for more details.
198+
"""
199+
200+
def __init__(self, tee_object: Any, lock: threading.Lock):
201+
self._tee_object = tee_object
202+
self._lock = lock
203+
204+
def __iter__(self):
205+
return self
206+
207+
def __next__(self):
208+
with self._lock:
209+
return next(self._tee_object)
210+
211+
def __copy__(self):
212+
return ThreadSafeIterator(self._tee_object.__copy__(), self._lock)
213+
214+
215+
def thread_safe_tee(
216+
iterable: Iterable[Any], n: int
217+
) -> tuple[ThreadSafeIterator, ...]:
218+
"""Returns a tuple of n independent thread-safe iterators."""
219+
lock = threading.Lock()
220+
return tuple(
221+
ThreadSafeIterator(tee_object, lock)
222+
for tee_object in itertools.tee(iterable, n)
223+
)
224+
225+
193226
class MultiOutputExampleWriter(ExampleWriter):
194227
"""Example writer that can write multiple outputs."""
195228

@@ -207,7 +240,7 @@ def write(
207240
"""Writes examples to multiple outputs."""
208241
write_fns = []
209242
for writer, my_iter in zip(
210-
self._writers, itertools.tee(examples, len(self._writers))
243+
self._writers, thread_safe_tee(examples, len(self._writers))
211244
):
212245
if file_format := writer.file_format:
213246
shard_path = os.fspath(

0 commit comments

Comments
 (0)