Skip to content

Commit c1b13d1

Browse files
author
The TensorFlow Datasets Authors
committed
Add support for multi-threaded use of reraise_with_context.
PiperOrigin-RevId: 673306041
1 parent a691e0a commit c1b13d1

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

tensorflow_datasets/core/read_only_builder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,20 @@ def _find_builder_dir(name: str, **builder_kwargs: Any) -> str | None:
356356
all_builder_dirs.add(builder_dir)
357357
else:
358358
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
359-
for builder_dir in executor.map(find_builder_fn, all_data_dirs):
359+
# Keep track of each new thread's error context, and add it to the main
360+
# error context when each thread finishes.
361+
def wrapped_find_builder_fn(data_dir):
362+
with error_utils.record_error_context() as thread_context:
363+
builder_dir = find_builder_fn(data_dir)
364+
return thread_context, builder_dir
365+
366+
for context, builder_dir in executor.map(
367+
wrapped_find_builder_fn, all_data_dirs
368+
):
360369
if builder_dir:
361370
all_builder_dirs.add(builder_dir)
371+
for msg in context.messages:
372+
error_utils.add_context(msg)
362373

363374
if not all_builder_dirs:
364375
all_dirs_str = '\n\t- '.join([''] + [str(dir) for dir in all_data_dirs])

tensorflow_datasets/core/utils/error_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,29 @@ class ErrorContext:
4141
@edc.dataclass
4242
@dataclasses.dataclass
4343
class ContextHolder:
44-
current_context_msg: ErrorContext | None = None
44+
# Each thread will use its own instance of current_context_msg.
45+
current_context_msg: edc.ContextVar[ErrorContext | None] = None
4546

4647

4748
context_holder = ContextHolder()
4849

4950

51+
@contextlib.contextmanager
52+
def record_error_context() -> Iterator[ErrorContext]:
53+
"""Contextmanager which captures the error context for a thread."""
54+
55+
if context_holder.current_context_msg is not None:
56+
raise ValueError(
57+
'Cannot record error context within the scope of another error context.'
58+
)
59+
60+
context_holder.current_context_msg = ErrorContext()
61+
try:
62+
yield context_holder.current_context_msg
63+
finally:
64+
context_holder.current_context_msg = None
65+
66+
5067
@contextlib.contextmanager
5168
def reraise_with_context(error_cls: Type[Exception]) -> Iterator[None]:
5269
"""Contextmanager which reraises an exception with an additional message.

0 commit comments

Comments
 (0)