2424
2525import re
2626import sys
27+ from collections .abc import Callable
2728from datetime import datetime , timedelta , timezone
2829from json import loads as json_loads
29- from typing import Any , Callable , Set
30+ from typing import Generic , NoReturn , TypeVar , cast
3031from warnings import WarningMessage , catch_warnings
3132
32- from typing_extensions import deprecated
33+ from typing_extensions import Self , deprecated
3334
35+ _E = TypeVar ("_E" , bound = BaseException )
3436
35- def fail (msg = None ):
37+
38+ def fail (msg : str = "" ) -> NoReturn :
3639 """Raise an AssertionError with the given message.
3740
3841 >>> fail("my message")
@@ -868,7 +871,7 @@ def assert_datetime_about_now_utc(actual, msg_fmt="{msg}"):
868871 fail (msg_fmt .format (msg = msg , actual = actual , now = now ))
869872
870873
871- class AssertRaisesContext :
874+ class AssertRaisesContext ( Generic [ _E ]) :
872875 """A context manager to test for exceptions with certain properties.
873876
874877 When the context is left and no exception has been raised, an
@@ -904,36 +907,41 @@ class AssertRaisesContext:
904907
905908 """
906909
907- def __init__ (self , exception , msg_fmt = "{msg}" ):
910+ def __init__ (self , exception : type [ _E ] , msg_fmt : str = "{msg}" ) -> None :
908911 self .exception = exception
909912 self .msg_fmt = msg_fmt
910913 self ._exc_type = exception
911- self ._exc_val = None
912- self ._exception_name = getattr (exception , "__name__" , str (exception ))
913- self ._tests : list [Callable [[Any ], object ]] = []
914+ self ._exc_val : _E | None = None
915+ self ._exception_name : str = getattr (exception , "__name__" , str (exception ))
916+ self ._tests : list [Callable [[_E ], object ]] = []
914917
915- def __enter__ (self ):
918+ def __enter__ (self ) -> Self :
916919 return self
917920
918- def __exit__ (self , exc_type , exc_val , exc_tb ):
921+ def __exit__ (
922+ self ,
923+ exc_type : type [BaseException ] | None ,
924+ exc_val : BaseException | None ,
925+ exc_tb : object ,
926+ ) -> bool :
919927 if not exc_type or not exc_val :
920- msg = "{ } not raised". format ( self . _exception_name )
928+ msg = f" { self . _exception_name } not raised"
921929 fail (self .format_message (msg ))
922- self ._exc_val = exc_val
923930 if not issubclass (exc_type , self .exception ):
924931 return False
932+ self ._exc_val = cast (_E , exc_val )
925933 for test in self ._tests :
926- test (exc_val )
934+ test (self . _exc_val )
927935 return True
928936
929- def format_message (self , default_msg ) :
937+ def format_message (self , default_msg : str ) -> str :
930938 return self .msg_fmt .format (
931939 msg = default_msg ,
932940 exc_type = self ._exc_type ,
933941 exc_name = self ._exception_name ,
934942 )
935943
936- def add_test (self , cb : Callable [[Any ], object ]) -> None :
944+ def add_test (self , cb : Callable [[_E ], object ]) -> None :
937945 """Add a test callback.
938946
939947 This callback is called after determining that the right exception
@@ -944,7 +952,8 @@ class was raised. The callback will get the raised exception as only
944952 self ._tests .append (cb )
945953
946954 @property
947- def exc_val (self ):
955+ def exc_val (self ) -> _E :
956+ """The exception value that was raised within the context."""
948957 if self ._exc_val is None :
949958 raise RuntimeError ("must be called after leaving the context" )
950959 return self ._exc_val
@@ -1459,8 +1468,8 @@ def _assert_dict_values_equal(self) -> None:
14591468 self ._assert_json_value_equals_with_item (name )
14601469
14611470 @property
1462- def _expected_key_names (self ) -> Set [str ]:
1463- keys : Set [str ] = set ()
1471+ def _expected_key_names (self ) -> set [str ]:
1472+ keys : set [str ] = set ()
14641473 for k in self ._expected .keys ():
14651474 if isinstance (k , str ):
14661475 if not _is_absent (self ._expected [k ]):
0 commit comments