2828
2929T = TypeVar ("T" )
3030U = TypeVar ("U" )
31+ _E = TypeVar ("_E" , bound = BaseException )
32+ _E2 = TypeVar ("_E2" , bound = BaseException )
33+ _E3 = TypeVar ("_E3" , bound = BaseException )
3134
3235# ruff: noqa: E402 - TypeVars must be defined before importing testtools modules
3336from testtools import content
@@ -513,17 +516,38 @@ def assertRaises(
513516 @overload # type: ignore[override]
514517 def assertRaises (
515518 self ,
516- expected_exception : type [BaseException ] | tuple [ type [ BaseException ] ],
519+ expected_exception : type [_E ],
517520 callable : None = ...,
518- ) -> "_AssertRaisesContext" : ...
521+ ) -> "_AssertRaisesContext[_E] " : ...
519522
520- def assertRaises ( # type: ignore[override]
523+ @overload # type: ignore[override]
524+ def assertRaises (
521525 self ,
522- expected_exception : type [BaseException ] | tuple [type [BaseException ]],
526+ expected_exception : tuple [type [_E ], type [_E2 ]],
527+ callable : None = ...,
528+ ) -> "_AssertRaisesContext[_E | _E2]" : ...
529+
530+ @overload # type: ignore[override]
531+ def assertRaises (
532+ self ,
533+ expected_exception : tuple [type [_E ], type [_E2 ], type [_E3 ]],
534+ callable : None = ...,
535+ ) -> "_AssertRaisesContext[_E | _E2 | _E3]" : ...
536+
537+ @overload # type: ignore[override]
538+ def assertRaises (
539+ self ,
540+ expected_exception : tuple [type [BaseException ], ...],
541+ callable : None = ...,
542+ ) -> "_AssertRaisesContext[BaseException]" : ...
543+
544+ def assertRaises ( # type: ignore[override, misc]
545+ self ,
546+ expected_exception : type [BaseException ] | tuple [type [BaseException ], ...],
523547 callable : Callable [_P , _R ] | None = None ,
524548 * args : _P .args ,
525549 ** kwargs : _P .kwargs ,
526- ) -> "_AssertRaisesContext | BaseException" :
550+ ) -> "_AssertRaisesContext[BaseException] | BaseException" :
527551 """Fail unless an exception of class expected_exception is thrown
528552 by callable when invoked with arguments args and keyword
529553 arguments kwargs. If a different type of exception is
@@ -1223,15 +1247,15 @@ def _id(obj: _F) -> _F:
12231247 return _id
12241248
12251249
1226- class _AssertRaisesContext :
1250+ class _AssertRaisesContext ( Generic [ _E ]) :
12271251 """A context manager to handle expected exceptions for assertRaises.
12281252
12291253 This provides compatibility with unittest's assertRaises context manager.
12301254 """
12311255
12321256 def __init__ (
12331257 self ,
1234- expected : type [BaseException ] | tuple [type [BaseException ]],
1258+ expected : type [_E ] | tuple [type [BaseException ], ... ],
12351259 test_case : TestCase ,
12361260 msg : str | None = None ,
12371261 ) -> None :
@@ -1244,7 +1268,7 @@ def __init__(
12441268 self .expected = expected
12451269 self .test_case = test_case
12461270 self .msg = msg
1247- self .exception : BaseException | None = None
1271+ self .exception : _E | None = None
12481272
12491273 def __enter__ (self ) -> "Self" :
12501274 return self
@@ -1274,7 +1298,7 @@ def __exit__(
12741298 # let unexpected exceptions pass through
12751299 return False
12761300 # store exception for later retrieval
1277- self .exception = exc_value
1301+ self .exception = cast ( _E , exc_value )
12781302 return True
12791303
12801304
0 commit comments