17
17
18
18
import contextlib
19
19
import dataclasses
20
- import threading
21
20
from typing import Callable , Iterator , List , Type , Union
22
21
22
+ from etils import edc
23
23
from tensorflow_datasets .core import utils
24
24
25
25
Message = Union [str , Callable [[], str ]]
@@ -38,7 +38,13 @@ class ErrorContext:
38
38
39
39
40
40
# Current error context. Accessed by `reraise_with_context` and `add_context`.
41
- context_holder = threading .local ()
41
+ @edc .dataclass
42
+ @dataclasses .dataclass
43
+ class ContextHolder :
44
+ current_context_msg : edc .ContextVar [ErrorContext | None ] = None
45
+
46
+
47
+ context_holder = ContextHolder ()
42
48
43
49
44
50
@contextlib .contextmanager
@@ -53,7 +59,7 @@ def reraise_with_context(error_cls: Type[Exception]) -> Iterator[None]:
53
59
"""
54
60
# If current_context_msg exists, we are already within the scope of the
55
61
# session contextmanager.
56
- if hasattr ( context_holder , ' current_context_msg' ) :
62
+ if context_holder . current_context_msg is not None :
57
63
yield
58
64
return
59
65
@@ -64,7 +70,7 @@ def reraise_with_context(error_cls: Type[Exception]) -> Iterator[None]:
64
70
context_msg = '\n ' .join (context_holder .current_context_msg .messages )
65
71
utils .reraise (e , suffix = context_msg )
66
72
finally :
67
- del context_holder .current_context_msg
73
+ context_holder .current_context_msg = None
68
74
69
75
70
76
def add_context (msg : str ) -> None :
@@ -79,7 +85,7 @@ def add_context(msg: str) -> None:
79
85
Raises:
80
86
AttributeError if local thread has no current_context_msg attribute.
81
87
"""
82
- if not hasattr ( context_holder , ' current_context_msg' ) :
88
+ if context_holder . current_context_msg is None :
83
89
raise AttributeError (
84
90
'add_context called outside of reraise_with_context contextmanager.'
85
91
)
0 commit comments