@@ -436,14 +436,32 @@ def assertIsInstance(self, obj, klass, msg=None):
436436 matcher = IsInstance (klass )
437437 self .assertThat (obj , matcher , msg )
438438
439- def assertRaises (self , excClass , callableObj , * args , ** kwargs ):
439+ def assertRaises (self , excClass , callableObj = None , * args , ** kwargs ):
440440 """Fail unless an exception of class excClass is thrown
441441 by callableObj when invoked with arguments args and keyword
442442 arguments kwargs. If a different type of exception is
443443 thrown, it will not be caught, and the test case will be
444444 deemed to have suffered an error, exactly as for an
445445 unexpected exception.
446+
447+ If called with the callable omitted, will return a
448+ context object used like this::
449+
450+ with self.assertRaises(SomeException):
451+ do_something()
452+
453+ The context manager keeps a reference to the exception as
454+ the 'exception' attribute. This allows you to inspect the
455+ exception after the assertion::
456+
457+ with self.assertRaises(SomeException) as cm:
458+ do_something()
459+ the_exception = cm.exception
460+ self.assertEqual(the_exception.error_code, 3)
446461 """
462+ # If callableObj is None, we're being used as a context manager
463+ if callableObj is None :
464+ return _AssertRaisesContext (excClass , self , msg = kwargs .get ("msg" ))
447465
448466 class ReRaiseOtherTypes :
449467 def match (self , matchee ):
@@ -1011,6 +1029,51 @@ def _id(obj):
10111029 return _id
10121030
10131031
1032+ class _AssertRaisesContext :
1033+ """A context manager to handle expected exceptions for assertRaises.
1034+
1035+ This provides compatibility with unittest's assertRaises context manager.
1036+ """
1037+
1038+ def __init__ (self , expected , test_case , msg = None ):
1039+ """Construct an `_AssertRaisesContext`.
1040+
1041+ :param expected: The type of exception to expect.
1042+ :param test_case: The TestCase instance using this context.
1043+ :param msg: An optional message explaining the failure.
1044+ """
1045+ self .expected = expected
1046+ self .test_case = test_case
1047+ self .msg = msg
1048+ self .exception = None
1049+
1050+ def __enter__ (self ):
1051+ return self
1052+
1053+ def __exit__ (self , exc_type , exc_value , traceback ):
1054+ if exc_type is None :
1055+ try :
1056+ if isinstance (self .expected , tuple ):
1057+ exc_name = "({})" .format (
1058+ ", " .join (e .__name__ for e in self .expected )
1059+ )
1060+ else :
1061+ exc_name = self .expected .__name__
1062+ except AttributeError :
1063+ exc_name = str (self .expected )
1064+ if self .msg :
1065+ error_msg = "{} not raised : {}" .format (exc_name , self .msg )
1066+ else :
1067+ error_msg = "{} not raised" .format (exc_name )
1068+ raise self .test_case .failureException (error_msg )
1069+ if not issubclass (exc_type , self .expected ):
1070+ # let unexpected exceptions pass through
1071+ return False
1072+ # store exception for later retrieval
1073+ self .exception = exc_value
1074+ return True
1075+
1076+
10141077class ExpectedException :
10151078 """A context manager to handle expected exceptions.
10161079
0 commit comments