|
23 | 23 | import types |
24 | 24 | import unittest |
25 | 25 | from collections.abc import Callable, Iterator |
26 | | -from typing import TYPE_CHECKING, Generic, NoReturn, ParamSpec, TypeVar, cast, overload |
| 26 | +from typing import ( |
| 27 | + TYPE_CHECKING, |
| 28 | + Any, |
| 29 | + Generic, |
| 30 | + NoReturn, |
| 31 | + ParamSpec, |
| 32 | + TypeVar, |
| 33 | + cast, |
| 34 | + overload, |
| 35 | +) |
27 | 36 | from unittest.case import SkipTest |
28 | 37 |
|
29 | 38 | T = TypeVar("T") |
@@ -562,7 +571,16 @@ def assertIsInstance( # type: ignore[override] |
562 | 571 | @overload # type: ignore[override] |
563 | 572 | def assertRaises( |
564 | 573 | self, |
565 | | - expected_exception: type[BaseException] | tuple[type[BaseException]], |
| 574 | + expected_exception: type[_E], |
| 575 | + callable: Callable[_P, _R], |
| 576 | + *args: _P.args, |
| 577 | + **kwargs: _P.kwargs, |
| 578 | + ) -> _E: ... |
| 579 | + |
| 580 | + @overload |
| 581 | + def assertRaises( |
| 582 | + self, |
| 583 | + expected_exception: tuple[type[BaseException], ...], |
566 | 584 | callable: Callable[_P, _R], |
567 | 585 | *args: _P.args, |
568 | 586 | **kwargs: _P.kwargs, |
@@ -596,13 +614,13 @@ def assertRaises( |
596 | 614 | callable: None = ..., |
597 | 615 | ) -> _AssertRaisesContext[BaseException]: ... |
598 | 616 |
|
599 | | - def assertRaises( # type: ignore[misc] |
| 617 | + def assertRaises( |
600 | 618 | self, |
601 | | - expected_exception: type[BaseException] | tuple[type[BaseException], ...], |
| 619 | + expected_exception: type[_E] | tuple[type[BaseException], ...], |
602 | 620 | callable: Callable[_P, _R] | None = None, |
603 | 621 | *args: _P.args, |
604 | 622 | **kwargs: _P.kwargs, |
605 | | - ) -> _AssertRaisesContext[BaseException] | BaseException: |
| 623 | + ) -> _AssertRaisesContext[Any] | BaseException: |
606 | 624 | """Fail unless an exception of class expected_exception is thrown |
607 | 625 | by callable when invoked with arguments args and keyword |
608 | 626 | arguments kwargs. If a different type of exception is |
@@ -670,6 +688,7 @@ def match( |
670 | 688 | ) |
671 | 689 | our_callable: Callable[[], object] = Nullary(callable, *args, **kwargs) |
672 | 690 | self.assertThat(our_callable, matcher) |
| 691 | + # we know that we have the right exception type now |
673 | 692 | return capture.matchee |
674 | 693 |
|
675 | 694 | def assertThat( |
|
0 commit comments