Skip to content

Commit 0dfa771

Browse files
committed
modified _make_constraint_param
1 parent 6284ea5 commit 0dfa771

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

rehline/_base.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def _make_constraint_rehline_param(constraint, X, y=None):
410410
Each dictionary must contain a 'name' key, which specifies the type of constraint.
411411
The following constraint types are supported:
412412
* 'nonnegative' or '>=0': A non-negativity constraint.
413-
* 'fair' or 'fairness': A fairness constraint.
413+
* 'fair' or 'fairness': A fairness constraint using 'sen_idx' and 'tol_sen'.
414414
* 'custom': A custom constraint, where the user must provide the constraint matrix 'A' and vector 'b'.
415415
416416
X : array-like of shape (n_samples, n_features)
@@ -424,43 +424,39 @@ def _make_constraint_rehline_param(constraint, X, y=None):
424424
A : array-like of shape (n_constraints, n_features)
425425
The constraint matrix.
426426
427-
b : array-like of shape (n_constraints,)
427+
B : array-like of shape (n_constraints,)
428428
The constraint vector.
429-
430-
Notes
431-
-----
432-
This function iterates over the list of constraints and generates the constraint matrix 'A' and vector 'b' accordingly.
433-
For 'nonnegative' and 'fair' constraints, the function generates the constraint parameters automatically.
434-
For 'custom' constraints, the user must provide the constraint matrix 'A' and vector 'b' explicitly.
435429
"""
436430

437431
n, d = X.shape
438432

439433
## initialization
440-
A=np.empty(shape=(0, 0))
441-
b=np.empty(shape=(0))
434+
A = np.empty(shape=(0, 0))
435+
b = np.empty(shape=(0))
442436

443437
for constr_tmp in constraint:
444438
if (constr_tmp['name'] == 'nonnegative') or (constr_tmp['name'] == '>=0'):
445439
A_tmp = np.identity(d)
446440
b_tmp = np.zeros(d)
441+
447442
elif (constr_tmp['name'] == 'fair') or (constr_tmp['name'] == 'fairness'):
448-
X_sen = constr_tmp['X_sen']
443+
sen_idx = constr_tmp['sen_idx'] # list of indices
449444
tol_sen = constr_tmp['tol_sen']
450445
tol_sen = np.array(tol_sen).reshape(-1)
451446

452-
assert len(X_sen) == len(X), "X and X_sen must have the same length"
453-
X_sen = X_sen.reshape(n,-1)
447+
X_sen = X[:, sen_idx]
448+
X_sen = X_sen.reshape(n, -1)
454449

455450
assert X_sen.shape[1] == len(tol_sen), "dim of X_sen and len of tol_sen must be equal"
456-
d_sen = X_sen.shape[1]
457451

458452
A_tmp = np.repeat(X_sen.T @ X, repeats=[2], axis=0) / n
459453
A_tmp[::2] = -A_tmp[::2]
460454
b_tmp = np.repeat(tol_sen, repeats=[2], axis=0)
455+
461456
elif (constr_tmp['name'] == 'custom'):
462457
A_tmp = constr_tmp['A']
463458
b_tmp = constr_tmp['b']
459+
464460
else:
465461
raise Exception("Sorry, ReHLine currently does not support this constraint, \
466462
but you can add it by manually setting A and b via {'name': 'custom', 'A': A, 'b': b}")
@@ -470,6 +466,7 @@ def _make_constraint_rehline_param(constraint, X, y=None):
470466

471467
return A, b
472468

469+
473470
def _make_penalty_rehline_param(self, penalty=None, X=None):
474471
"""The `_make_penalty_rehline_param` function generates penalty parameters for the ReHLine solver.
475472
"""

0 commit comments

Comments
 (0)