Skip to content

Commit 3998279

Browse files
committed
Correct type for assertRaises
Correct two mistakes: * We need an ellipsis to make this a vararg tuple. * Narrowing to a TypeVar is not correct when we're passed a tuple Signed-off-by: Stephen Finucane <stephen@that.guru>
1 parent 5784965 commit 3998279

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

testtools/testcase.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,16 @@
2323
import types
2424
import unittest
2525
from collections.abc import Callable, Iterator
26-
from typing import TYPE_CHECKING, Generic, NoReturn, ParamSpec, TypeVar, cast, overload
26+
from typing import (
27+
TYPE_CHECKING,
28+
Any,
29+
Generic,
30+
NoReturn,
31+
ParamSpec,
32+
TypeVar,
33+
cast,
34+
overload,
35+
)
2736
from unittest.case import SkipTest
2837

2938
T = TypeVar("T")
@@ -562,7 +571,16 @@ def assertIsInstance( # type: ignore[override]
562571
@overload # type: ignore[override]
563572
def assertRaises(
564573
self,
565-
expected_exception: type[BaseException] | tuple[type[BaseException]],
574+
expected_exception: type[_E],
575+
callable: Callable[_P, _R],
576+
*args: _P.args,
577+
**kwargs: _P.kwargs,
578+
) -> _E: ...
579+
580+
@overload
581+
def assertRaises(
582+
self,
583+
expected_exception: tuple[type[BaseException], ...],
566584
callable: Callable[_P, _R],
567585
*args: _P.args,
568586
**kwargs: _P.kwargs,
@@ -596,13 +614,13 @@ def assertRaises(
596614
callable: None = ...,
597615
) -> _AssertRaisesContext[BaseException]: ...
598616

599-
def assertRaises( # type: ignore[misc]
617+
def assertRaises(
600618
self,
601-
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
619+
expected_exception: type[_E] | tuple[type[BaseException], ...],
602620
callable: Callable[_P, _R] | None = None,
603621
*args: _P.args,
604622
**kwargs: _P.kwargs,
605-
) -> _AssertRaisesContext[BaseException] | BaseException:
623+
) -> _AssertRaisesContext[Any] | BaseException:
606624
"""Fail unless an exception of class expected_exception is thrown
607625
by callable when invoked with arguments args and keyword
608626
arguments kwargs. If a different type of exception is
@@ -670,6 +688,7 @@ def match(
670688
)
671689
our_callable: Callable[[], object] = Nullary(callable, *args, **kwargs)
672690
self.assertThat(our_callable, matcher)
691+
# we know that we have the right exception type now
673692
return capture.matchee
674693

675694
def assertThat(

0 commit comments

Comments
 (0)