@@ -410,7 +410,7 @@ def _make_constraint_rehline_param(constraint, X, y=None):
410410 Each dictionary must contain a 'name' key, which specifies the type of constraint.
411411 The following constraint types are supported:
412412 * 'nonnegative' or '>=0': A non-negativity constraint.
413- * 'fair' or 'fairness': A fairness constraint.
413+ * 'fair' or 'fairness': A fairness constraint using 'sen_idx' and 'tol_sen' .
414414 * 'custom': A custom constraint, where the user must provide the constraint matrix 'A' and vector 'b'.
415415
416416 X : array-like of shape (n_samples, n_features)
@@ -424,43 +424,39 @@ def _make_constraint_rehline_param(constraint, X, y=None):
424424 A : array-like of shape (n_constraints, n_features)
425425 The constraint matrix.
426426
427- b : array-like of shape (n_constraints,)
427+ B : array-like of shape (n_constraints,)
428428 The constraint vector.
429-
430- Notes
431- -----
432- This function iterates over the list of constraints and generates the constraint matrix 'A' and vector 'b' accordingly.
433- For 'nonnegative' and 'fair' constraints, the function generates the constraint parameters automatically.
434- For 'custom' constraints, the user must provide the constraint matrix 'A' and vector 'b' explicitly.
435429 """
436430
437431 n , d = X .shape
438432
439433 ## initialization
440- A = np .empty (shape = (0 , 0 ))
441- b = np .empty (shape = (0 ))
434+ A = np .empty (shape = (0 , 0 ))
435+ b = np .empty (shape = (0 ))
442436
443437 for constr_tmp in constraint :
444438 if (constr_tmp ['name' ] == 'nonnegative' ) or (constr_tmp ['name' ] == '>=0' ):
445439 A_tmp = np .identity (d )
446440 b_tmp = np .zeros (d )
441+
447442 elif (constr_tmp ['name' ] == 'fair' ) or (constr_tmp ['name' ] == 'fairness' ):
448- X_sen = constr_tmp ['X_sen' ]
443+ sen_idx = constr_tmp ['sen_idx' ] # list of indices
449444 tol_sen = constr_tmp ['tol_sen' ]
450445 tol_sen = np .array (tol_sen ).reshape (- 1 )
451446
452- assert len ( X_sen ) == len ( X ), "X and X_sen must have the same length"
453- X_sen = X_sen .reshape (n ,- 1 )
447+ X_sen = X [:, sen_idx ]
448+ X_sen = X_sen .reshape (n , - 1 )
454449
455450 assert X_sen .shape [1 ] == len (tol_sen ), "dim of X_sen and len of tol_sen must be equal"
456- d_sen = X_sen .shape [1 ]
457451
458452 A_tmp = np .repeat (X_sen .T @ X , repeats = [2 ], axis = 0 ) / n
459453 A_tmp [::2 ] = - A_tmp [::2 ]
460454 b_tmp = np .repeat (tol_sen , repeats = [2 ], axis = 0 )
455+
461456 elif (constr_tmp ['name' ] == 'custom' ):
462457 A_tmp = constr_tmp ['A' ]
463458 b_tmp = constr_tmp ['b' ]
459+
464460 else :
465461 raise Exception ("Sorry, ReHLine currently does not support this constraint, \
466462 but you can add it by manually setting A and b via {'name': 'custom', 'A': A, 'b': b}" )
@@ -470,6 +466,7 @@ def _make_constraint_rehline_param(constraint, X, y=None):
470466
471467 return A , b
472468
469+
473470def _make_penalty_rehline_param (self , penalty = None , X = None ):
474471 """The `_make_penalty_rehline_param` function generates penalty parameters for the ReHLine solver.
475472 """
0 commit comments