Skip to content

Commit 6f4f38b

Browse files
committed
test: add additional unit test for SPD and DIR
1 parent 17bc25c commit 6f4f38b

File tree

4 files changed

+2025
-1436
lines changed

4 files changed

+2025
-1436
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ dependencies = [
1717
"h5py>=3.13.0,<4",
1818
"scikit-learn",
1919
"aif360",
20+
"hypothesis>=6.136.2",
21+
"pytest>=8.4.1",
2022
]
2123

2224
[project.optional-dependencies]

src/core/metrics/fairness/group/group_statistical_parity_difference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
import numpy as np
55
from sklearn.base import ClassifierMixin
66

7+
78
class GroupStatisticalParityDifference:
89
"""
910
Calculate group statistical parity difference (SPD).
1011
"""
12+
1113
@staticmethod
1214
def calculate_model(
1315
samples: np.ndarray,
1416
model: ClassifierMixin,
1517
privilege_columns: List[int],
1618
privilege_values: List[int],
1719
favorable_output,
18-
) -> float:
20+
) -> float:
1921
"""
2022
Calculate group statistical parity difference (SPD) for model outputs.
2123
:param samples a NumPy array of inputs to be used for testing fairness
@@ -37,13 +39,13 @@ def calculate(
3739
privileged,
3840
unprivileged,
3941
favorable_output,
40-
) -> float:
42+
) -> float:
4143
"""
4244
Calculate statistical/demographic parity difference (SPD) when the labels are pre-calculated.
4345
:param priviledged numPy array with the privileged groups
4446
:param unpriviledged numPy array with the unpriviledged groups
4547
:param favorableOutput an output that is considered favorable / desirable
46-
return SPD, between 0 and 1
48+
return SPD, between -1 and 1
4749
"""
4850
probability_privileged = np.sum(privileged[:, -1] == favorable_output) / len(privileged)
4951
probability_unprivileged = np.sum(unprivileged[:, -1] == favorable_output) / len(unprivileged)

0 commit comments

Comments
 (0)