Skip to content

Commit ddf9764

Browse files
committed
typing: Split assertRaises context manager overload
This ensures we don't lose type information on the `exception` attribute of the context manager: with self.assertRaises(ValueError) as ctx: int("foo") reveal_type(ctx.exception) # now reveals ValueError We also narrow the tuple type, but varadic expansion is not a thing yet, so we need explicitly type for different lengths of tuple. We decide that at most 3 exceptions can be passed: any more and you'll lose this information, which seems fair. Signed-off-by: Stephen Finucane <stephen@that.guru>
1 parent c51db6d commit ddf9764

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

testtools/testcase.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
T = TypeVar("T")
3030
U = TypeVar("U")
31+
_E = TypeVar("_E", bound=BaseException)
32+
_E2 = TypeVar("_E2", bound=BaseException)
33+
_E3 = TypeVar("_E3", bound=BaseException)
3134

3235
# ruff: noqa: E402 - TypeVars must be defined before importing testtools modules
3336
from testtools import content
@@ -513,17 +516,38 @@ def assertRaises(
513516
@overload # type: ignore[override]
514517
def assertRaises(
515518
self,
516-
expected_exception: type[BaseException] | tuple[type[BaseException]],
519+
expected_exception: type[_E],
517520
callable: None = ...,
518-
) -> "_AssertRaisesContext": ...
521+
) -> "_AssertRaisesContext[_E]": ...
519522

520-
def assertRaises( # type: ignore[override]
523+
@overload # type: ignore[override]
524+
def assertRaises(
521525
self,
522-
expected_exception: type[BaseException] | tuple[type[BaseException]],
526+
expected_exception: tuple[type[_E], type[_E2]],
527+
callable: None = ...,
528+
) -> "_AssertRaisesContext[_E | _E2]": ...
529+
530+
@overload # type: ignore[override]
531+
def assertRaises(
532+
self,
533+
expected_exception: tuple[type[_E], type[_E2], type[_E3]],
534+
callable: None = ...,
535+
) -> "_AssertRaisesContext[_E | _E2 | _E3]": ...
536+
537+
@overload # type: ignore[override]
538+
def assertRaises(
539+
self,
540+
expected_exception: tuple[type[BaseException], ...],
541+
callable: None = ...,
542+
) -> "_AssertRaisesContext[BaseException]": ...
543+
544+
def assertRaises( # type: ignore[override, misc]
545+
self,
546+
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
523547
callable: Callable[_P, _R] | None = None,
524548
*args: _P.args,
525549
**kwargs: _P.kwargs,
526-
) -> "_AssertRaisesContext | BaseException":
550+
) -> "_AssertRaisesContext[BaseException] | BaseException":
527551
"""Fail unless an exception of class expected_exception is thrown
528552
by callable when invoked with arguments args and keyword
529553
arguments kwargs. If a different type of exception is
@@ -1223,15 +1247,15 @@ def _id(obj: _F) -> _F:
12231247
return _id
12241248

12251249

1226-
class _AssertRaisesContext:
1250+
class _AssertRaisesContext(Generic[_E]):
12271251
"""A context manager to handle expected exceptions for assertRaises.
12281252
12291253
This provides compatibility with unittest's assertRaises context manager.
12301254
"""
12311255

12321256
def __init__(
12331257
self,
1234-
expected: type[BaseException] | tuple[type[BaseException]],
1258+
expected: type[_E] | tuple[type[BaseException], ...],
12351259
test_case: TestCase,
12361260
msg: str | None = None,
12371261
) -> None:
@@ -1244,7 +1268,7 @@ def __init__(
12441268
self.expected = expected
12451269
self.test_case = test_case
12461270
self.msg = msg
1247-
self.exception: BaseException | None = None
1271+
self.exception: _E | None = None
12481272

12491273
def __enter__(self) -> "Self":
12501274
return self
@@ -1274,7 +1298,7 @@ def __exit__(
12741298
# let unexpected exceptions pass through
12751299
return False
12761300
# store exception for later retrieval
1277-
self.exception = exc_value
1301+
self.exception = cast(_E, exc_value)
12781302
return True
12791303

12801304

0 commit comments

Comments
 (0)