diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py index a12cf42a..76bbd941 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py @@ -316,6 +316,21 @@ class AttackInputData: # treated as multilabel data. force_multilabel_data: bool = False + # Directly comparing members of the type `np.ndarray` can throw an error: + # "The truth value of an array with more than one element is ambiguous." + # So we bring back an explicit implementation of __eq__ like it was prior to + # Python 3.13 in order work around this possibility. + def __eq__(self, other): + if self is other: + return True + if other.__class__ is self.__class__: + return tuple( + getattr(self, field.name) for field in dataclasses.fields(self) + ) == tuple( + getattr(other, field.name) for field in dataclasses.fields(other) + ) + return NotImplemented + @property def num_classes(self): if self.labels_train is None or self.labels_test is None: