diff --git a/testtools/matchers/_basic.py b/testtools/matchers/_basic.py index b3451a53..2f40e259 100644 --- a/testtools/matchers/_basic.py +++ b/testtools/matchers/_basic.py @@ -73,6 +73,14 @@ def __init__(self, actual, mismatch_string, reference, reference_on_right=True): self._reference_on_right = reference_on_right def describe(self): + # Special handling for set comparisons + if ( + self._mismatch_string == "!=" + and isinstance(self._reference, set) + and isinstance(self._actual, set) + ): + return self._describe_set_difference() + actual = repr(self._actual) reference = repr(self._reference) if len(actual) + len(reference) > 70: @@ -88,6 +96,27 @@ def describe(self): left, right = reference, actual return f"{left} {self._mismatch_string} {right}" + def _describe_set_difference(self): + """Describe the difference between two sets in a readable format.""" + reference_only = sorted( + self._reference - self._actual, key=lambda x: (type(x).__name__, x) + ) + actual_only = sorted( + self._actual - self._reference, key=lambda x: (type(x).__name__, x) + ) + + lines = ["!=:"] + if reference_only: + lines.append( + f"Items in expected but not in actual:\n{_format(reference_only)}" + ) + if actual_only: + lines.append( + f"Items in actual but not in expected:\n{_format(actual_only)}" + ) + + return "\n".join(lines) + class Equals(_BinaryComparison): """Matches if the items are equal.""" diff --git a/testtools/tests/test_testcase.py b/testtools/tests/test_testcase.py index c4764310..c304af6a 100644 --- a/testtools/tests/test_testcase.py +++ b/testtools/tests/test_testcase.py @@ -863,6 +863,44 @@ def test_assertEqual_non_ascii_str_with_newlines(self): ) self.assertFails(expected_error, self.assertEqual, a, b, message) + def test_assertEqual_set_difference(self): + a = {1, 2, 3, 4} + b = {2, 3, 5, 6} + expected_error = "\n".join( + [ + "!=:", + "Items in expected but not in actual:", + "[1, 4]", + "Items in actual but not in expected:", + "[5, 6]", + ] + ) + self.assertFails(expected_error, self.assertEqual, a, b) + + def test_assertEqual_set_missing_items_only(self): + a = {1, 2, 3, 4} + b = {2, 3} + expected_error = "\n".join( + [ + "!=:", + "Items in expected but not in actual:", + "[1, 4]", + ] + ) + self.assertFails(expected_error, self.assertEqual, a, b) + + def test_assertEqual_set_extra_items_only(self): + a = {1, 2} + b = {1, 2, 3, 4} + expected_error = "\n".join( + [ + "!=:", + "Items in actual but not in expected:", + "[3, 4]", + ] + ) + self.assertFails(expected_error, self.assertEqual, a, b) + def test_assertIsNone(self): self.assertIsNone(None)