11# pylint: disable=line-too-long, missing-function-docstring 
2- from  typing  import  List , Optional 
3- 
42from  pytest  import  approx 
53import  numpy  as  np 
64import  pandas  as  pd 
75
86from  sklearn .linear_model  import  LogisticRegression 
7+ from  sklearn .neighbors  import  NearestNeighbors 
98from  sklearn .preprocessing  import  LabelEncoder 
109
1110from  aif360 .sklearn .metrics  import  (
1211    disparate_impact_ratio ,
1312    statistical_parity_difference ,
1413    average_odds_difference ,
1514    average_predictive_value_difference ,
15+     consistency_score 
1616)
1717
1818from  src .core .metrics .fairness .group .disparate_impact_ratio  import  DisparateImpactRatio 
1919from  src .core .metrics .fairness .group .group_average_odds_difference  import  GroupAverageOddsDifference 
2020from  src .core .metrics .fairness .group .group_average_predictive_value_difference  import  GroupAveragePredictiveValueDifference 
2121from  src .core .metrics .fairness .group .group_statistical_parity_difference  import  GroupStatisticalParityDifference 
22+ from  src .core .metrics .fairness .individual .individual_consistency  import  IndividualConsistency 
2223
2324df  =  pd .read_csv (
24-     "https://raw.githubusercontent.com/trustyai-explainability/model-collection/8aa8e2e762c6d2b41dbcbe8a0035d50aa5f58c93/bank-churn/ data/train .csv" ,
25+     "tests/ data/bank_churn_train .csv" ,
2526)
2627X  =  df .drop (columns = ["Exited" ], axis = 1 )
2728y  =  df ["Exited" ]
@@ -34,8 +35,7 @@ def train_model():
3435        X [feature ] =  label_encoders [feature ].fit_transform (X [feature ])
3536    lr  =  LogisticRegression ().fit (X , y )
3637
37-     y_pred  =  pd .DataFrame (lr .predict (X ))
38-     return  y_pred 
38+     return  pd .DataFrame (lr .predict (X ))
3939
4040def  truth_predict_output ():
4141    y .index  =  X ["Gender" ]
@@ -58,20 +58,92 @@ def get_labeled_data():
5858    data_pred [:, - 1 ] =  y_pred .to_numpy ().flatten ()
5959    return  data , data_pred 
6060
61+ 
62+ def  get_k_neighbors_function (k_value = 5 ):
63+     """Create a function that returns k nearest neighbors for a given input.""" 
64+ 
65+     def  find_neighbors (sample , samples ):
66+         """Find k nearest neighbors for a given sample.""" 
67+         if  isinstance (sample , np .ndarray ) and  sample .ndim  >  1 :
68+             sample  =  sample .flatten ()
69+ 
70+         nbrs  =  NearestNeighbors (n_neighbors = k_value  +  1 , algorithm = 'ball_tree' ).fit (samples )
71+         distances , indices  =  nbrs .kneighbors ([sample ])
72+ 
73+         neighbor_indices  =  indices [0 ][1 :k_value  +  1 ]
74+         return  samples [neighbor_indices ]
75+ 
76+     return  find_neighbors 
77+ 
78+ 
79+ def  get_processed_data (sample_size = None ):
80+     """Process data for testing individual consistency.""" 
81+     categorical_features  =  ['Geography' , 'Gender' , 'Card Type' , 'HasCrCard' , 'IsActiveMember' , 'Complain' ]
82+     X_processed  =  X .copy ()
83+     for  feature  in  categorical_features :
84+         if  feature  in  X_processed .columns :
85+             le  =  LabelEncoder ()
86+             X_processed [feature ] =  le .fit_transform (X_processed [feature ])
87+ 
88+     if  sample_size  is  not None :
89+         return  X_processed .to_numpy ()[:sample_size ]
90+     return  X_processed .to_numpy ()
91+ 
92+ 
93+ class  MockPredictionProvider :
94+     """Mock prediction provider for testing.""" 
95+ 
96+     def  __init__ (self , predictions ):
97+         self .predictions  =  predictions 
98+ 
99+     def  predict (self , x ):
100+         """Return prediction for input.""" 
101+         if  isinstance (x , np .ndarray ) and  x .ndim  ==  1 :
102+             x  =  x .reshape (1 , - 1 )
103+ 
104+         result  =  []
105+         for  i  in  range (x .shape [0 ]):
106+             if  i  <  len (self .predictions ):
107+                 result .append ([self .predictions [i ][0 ]])
108+             else :
109+                 result .append ([0 ])
110+         return  result 
111+ 
112+ 
113+ class  PerfectConsistencyProvider :
114+     """Provider that always returns the same prediction.""" 
115+ 
116+     def  predict (self , x ):
117+         if  isinstance (x , np .ndarray ) and  x .ndim  ==  1 :
118+             x  =  x .reshape (1 , - 1 )
119+         return  [[1 ] for  _  in  range (x .shape [0 ])]
120+ 
121+ 
122+ class  RandomPredictionProvider :
123+     """Provider that returns random predictions.""" 
124+ 
125+     def  __init__ (self , seed = 42 ):
126+         self .rng  =  np .random .RandomState (seed )
127+ 
128+     def  predict (self , x ):
129+         if  isinstance (x , np .ndarray ) and  x .ndim  ==  1 :
130+             x  =  x .reshape (1 , - 1 )
131+         return  [[self .rng .randint (0 , 2 )] for  _  in  range (x .shape [0 ])]
132+ 
61133y , y_pred  =  truth_predict_output ()
62134privileged , unprivileged  =  get_privileged_unprivleged_split ()
63135data , data_pred  =  get_labeled_data ()
64136
65137
66138def  test_disparate_impact_ratio ():
67-     dir  =  disparate_impact_ratio (y , prot_attr = "Gender" , priv_group = "Male" , pos_label = 1 )
139+     dir_result  =  disparate_impact_ratio (y , prot_attr = "Gender" , priv_group = "Male" , pos_label = 1 )
68140
69141    score  =  DisparateImpactRatio .calculate (
70142        privileged = privileged ,
71143        unprivileged = unprivileged ,
72144        favorable_output = 1 
73145    )
74-     assert  score  ==  approx (dir , abs = 1e-5 )
146+     assert  score  ==  approx (dir_result , abs = 1e-5 )
75147
76148
77149def  test_statistical_parity_difference ():
@@ -98,7 +170,7 @@ def test_average_odds_difference():
98170        output_column = - 1 
99171    )
100172
101-     assert  score  ==  approx (aod , abs = 0.2 )
173+     assert  score  ==  approx (aod , abs = 1e-5 )
102174
103175
104176def  test_average_predictive_value_difference ():
@@ -114,3 +186,62 @@ def test_average_predictive_value_difference():
114186    )
115187
116188    assert  score  ==  approx (apvd , abs = 0.2 )
189+ 
190+ 
191+ def  test_individual_consistency ():
192+     """Test individual consistency calculation using AIF360's consistency_score as ground truth.""" 
193+     X_sample  =  get_processed_data (sample_size = 50 )
194+     y_pred_sample  =  y_pred .iloc [:50 ].to_numpy ()
195+ 
196+     k  =  5 
197+     cs_score  =  consistency_score (X_sample , y_pred_sample .flatten ())
198+ 
199+     prediction_provider  =  MockPredictionProvider (y_pred_sample )
200+     proximity_function  =  get_k_neighbors_function (k )
201+ 
202+     score  =  IndividualConsistency .calculate (
203+         proximity_function = proximity_function ,
204+         samples = X_sample ,
205+         prediction_provider = prediction_provider 
206+     )
207+ 
208+     assert  score  ==  approx (cs_score , abs = 0.2 )
209+ 
210+ 
211+ def  test_individual_consistency_perfect ():
212+     """Test individual consistency with a perfect consistency model.""" 
213+     X_sample  =  get_processed_data (sample_size = 20 )
214+ 
215+     perfect_predictions  =  np .ones (20 )
216+ 
217+     cs_score  =  consistency_score (X_sample , perfect_predictions )
218+ 
219+     proximity_function  =  get_k_neighbors_function (3 )
220+ 
221+     consistency  =  IndividualConsistency .calculate (
222+         proximity_function = proximity_function ,
223+         samples = X_sample ,
224+         prediction_provider = PerfectConsistencyProvider ()
225+     )
226+ 
227+     assert  consistency  ==  approx (cs_score , abs = 0.2 )
228+ 
229+ 
230+ def  test_individual_consistency_imperfect ():
231+     """Test individual consistency with an inconsistent model.""" 
232+     X_sample  =  get_processed_data (sample_size = 20 )
233+ 
234+     rng  =  np .random .RandomState (42 )
235+     random_predictions  =  rng .randint (0 , 2 , size = 20 )
236+ 
237+     cs_score  =  consistency_score (X_sample , random_predictions )
238+ 
239+     proximity_function  =  get_k_neighbors_function (3 )
240+ 
241+     consistency  =  IndividualConsistency .calculate (
242+         proximity_function = proximity_function ,
243+         samples = X_sample ,
244+         prediction_provider = RandomPredictionProvider (seed = 42 )
245+     )
246+ 
247+     assert  consistency  ==  approx (cs_score , abs = 0.2 )
0 commit comments