@@ -163,50 +163,6 @@ def test_preprocess_data_does_not_modify_original(self):
163163 assert original_data ['prediction' ].tolist () == ['True' , 'False' ]
164164 assert original_data ['sensitive' ].tolist () == ['A' , 'B' ]
165165
166- @patch ('sdmetrics.single_table.equalized_odds.XGBClassifier' )
167- def test_train_classifier (self , mock_xgb_class ):
168- """Test _train_classifier trains and returns XGBoost classifier."""
169- mock_classifier = Mock ()
170- mock_xgb_class .return_value = mock_classifier
171-
172- train_data = pd .DataFrame ({
173- 'feature1' : [1 , 2 , 3 ],
174- 'feature2' : [4 , 5 , 6 ],
175- 'target' : [0 , 1 , 0 ],
176- })
177-
178- result = EqualizedOddsImprovement ._train_classifier (train_data , 'target' )
179-
180- # Check classifier was created with correct parameters
181- mock_xgb_class .assert_called_once_with (enable_categorical = True )
182-
183- # Check fit was called with correct data
184- expected_features = pd .DataFrame ({
185- 'feature1' : [1 , 2 , 3 ],
186- 'feature2' : [4 , 5 , 6 ],
187- })
188- expected_target = pd .Series ([0 , 1 , 0 ], name = 'target' )
189-
190- mock_classifier .fit .assert_called_once ()
191- call_args = mock_classifier .fit .call_args [0 ]
192- pd .testing .assert_frame_equal (call_args [0 ], expected_features )
193- pd .testing .assert_series_equal (call_args [1 ], expected_target )
194-
195- assert result == mock_classifier
196-
197- def test_train_classifier_does_not_modify_original (self ):
198- """Test _train_classifier doesn't modify the original training data."""
199- original_data = pd .DataFrame ({
200- 'feature1' : [1 , 2 , 3 ],
201- 'target' : [0 , 1 , 0 ],
202- })
203-
204- with patch ('sdmetrics.single_table.equalized_odds.XGBClassifier' ):
205- EqualizedOddsImprovement ._train_classifier (original_data , 'target' )
206-
207- # Original data should still have target column
208- assert 'target' in original_data .columns
209-
210166 def test_compute_prediction_counts_both_groups (self ):
211167 """Test _compute_prediction_counts with data for both sensitive groups."""
212168 predictions = np .array ([1 , 0 , 1 , 0 , 1 , 0 ])
0 commit comments