Skip to content

Commit 8971b36

Browse files
authored
Merge pull request #22 from Leona-LYT/main
Delete _make_fair_classification function and modify the related test file
2 parents a4bb32e + cb96a21 commit 8971b36

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

rehline/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from ._data import make_fair_classification
66
from ._internal import rehline_internal, rehline_result
77
from ._path_sol import plqERM_Ridge_path_sol
8+
from ._sklearn_mixin import plq_Ridge_Classifier, plq_Ridge_Regressor
89

910
__all__ = ("ReHLine_solver",
1011
"_BaseReHLine",
1112
"ReHLine",
1213
"plqERM_Ridge",
1314
"CQR_Ridge",
1415
"plqERM_Ridge_path_sol",
16+
"plq_Ridge_Classifier",
17+
"plq_Ridge_Regressor",
1518
"_make_loss_rehline_param",
16-
"_make_constraint_rehline_param"
17-
"make_fair_classification")
19+
"_make_constraint_rehline_param")

rehline/_sklearn_mixin.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,6 @@ def fit(self, X, y, sample_weight=None):
410410
col = np.full((X.shape[0], 1), self.intercept_scaling, dtype=X.dtype)
411411
X_aug = np.hstack([X, col])
412412

413-
# Delegate to base solver (it will build loss/constraints from the X we pass)
414413
super().fit(X_aug, y, sample_weight=sample_weight)
415414

416415
# Split intercept from coefficients to match sklearn's linear model API

tests/_test_fairsvm.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
## Test SVM on simulated dataset
22
import numpy as np
3-
4-
from rehline import make_fair_classification, plqERM_Ridge
3+
from sklearn.datasets import make_classification
4+
from sklearn.preprocessing import StandardScaler
5+
from rehline import plqERM_Ridge
56

67
np.random.seed(1024)
78
# simulate classification dataset
8-
X, y, X_sen = make_fair_classification()
9-
n, d = X.shape
10-
C = 0.5
9+
n, d, C = 100, 5, 0.5
10+
X, y = make_classification(n, d)
11+
y = 2*y - 1
12+
13+
scaler = StandardScaler()
14+
X = scaler.fit_transform(X)
15+
sen_idx = [0]
1116

1217
## solution provided by ReHLine
1318
# build-in hinge loss for svm
1419
clf = plqERM_Ridge(loss={'name': 'svm'}, C=C)
1520

1621
# specific the param of FairSVM
22+
X_sen = X[:,sen_idx]
1723
A = np.repeat([X_sen @ X], repeats=[2], axis=0) / n
1824
A[1] = -A[1]
1925
# suppose the fair tolerance is 0.01

0 commit comments

Comments
 (0)