
Commit b17f7bf

Update _base.py
Move `_cast_sample_bias()` and `_cast_sample_weight()` to module level; the two methods can now be called as global functions.
1 parent a4bb32e commit b17f7bf

1 file changed: +137 −0 lines changed

rehline/_base.py

Lines changed: 137 additions & 0 deletions
@@ -473,6 +473,143 @@ def _make_penalty_rehline_param(self, penalty=None, X=None):
        raise Exception("Sorry, `_make_penalty_rehline_param` feature is currently under development.")


def _cast_sample_bias(U, V, Tau, S, T, sample_bias=None):
    """Cast sample bias to ReHLine parameters by injecting bias into V and T.

    This function modifies the ReHLine parameters to incorporate individual
    sample biases through linear transformations of the intercept parameters.

    Parameters
    ----------
    U : array-like of shape (L, n_samples)
        ReLU coefficient matrix.

    V : array-like of shape (L, n_samples)
        ReLU intercept vector.

    Tau : array-like of shape (H, n_samples)
        ReHU cutpoint matrix.

    S : array-like of shape (H, n_samples)
        ReHU coefficient vector.

    T : array-like of shape (H, n_samples)
        ReHU intercept vector.

    sample_bias : array-like of shape (n_samples, 1)
        Individual sample bias vector. If None, the parameters are returned unchanged.

    Returns
    -------
    U_bias : array-like of shape (L, n_samples)
        ReLU coefficient matrix; returned unchanged.

    V_bias : array-like of shape (L, n_samples)
        Biased ReLU intercept vector: V + U * sample_bias.

    Tau_bias : array-like of shape (H, n_samples)
        ReHU cutpoint matrix; returned unchanged.

    S_bias : array-like of shape (H, n_samples)
        ReHU coefficient vector; returned unchanged.

    T_bias : array-like of shape (H, n_samples)
        Biased ReHU intercept vector: T + S * sample_bias.

    Notes
    -----
    The transformation applies the sample bias through:

    - V_bias = V + U ⊙ sample_bias
    - T_bias = T + S ⊙ sample_bias

    where ⊙ denotes element-wise multiplication with broadcasting.
    """
    if sample_bias is None:
        return U, V, Tau, S, T
    else:
        # Reshape to (1, n_samples) so the bias broadcasts across the loss pieces.
        sample_bias = sample_bias.reshape(1, -1)
        U_bias = U
        V_bias = V + (U * sample_bias if U.shape[0] > 0 else 0)
        Tau_bias = Tau
        S_bias = S
        T_bias = T + (S * sample_bias if S.shape[0] > 0 else 0)

        return U_bias, V_bias, Tau_bias, S_bias, T_bias


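As a quick illustration of the bias cast (not part of the committed diff), the sketch below applies `_cast_sample_bias` to tiny hand-made arrays; the import path `rehline._base` and the toy values are assumptions, with L = 1 ReLU piece, H = 1 ReHU piece, and n_samples = 3.

import numpy as np
from rehline._base import _cast_sample_bias  # assumed import path

# Toy ReHLine parameters (shapes follow the docstring above).
U = np.array([[1.0, -1.0, 2.0]])   # ReLU coefficients, shape (L, n_samples)
V = np.array([[0.5, 0.5, 0.5]])    # ReLU intercepts
Tau = np.array([[1.0, 1.0, 1.0]])  # ReHU cutpoints, shape (H, n_samples)
S = np.array([[1.0, 1.0, 1.0]])    # ReHU coefficients
T = np.array([[0.0, 0.0, 0.0]])    # ReHU intercepts

bias = np.array([0.1, 0.2, 0.3])   # one bias per sample
U_b, V_b, Tau_b, S_b, T_b = _cast_sample_bias(U, V, Tau, S, T, sample_bias=bias)
# Only the intercepts move: V_b == [[0.6, 0.3, 1.1]] and T_b == [[0.1, 0.2, 0.3]];
# U_b, Tau_b, S_b are returned unchanged.

The committed diff continues with the companion helper `_cast_sample_weight`: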
def _cast_sample_weight(U, V, Tau, S, T, C=1.0, sample_weight=None):
    """Apply sample weights and regularization to ReHLine parameters.

    Parameters
    ----------
    U : array-like of shape (L, n_samples)
        ReLU coefficient matrix.

    V : array-like of shape (L, n_samples)
        ReLU intercept vector.

    Tau : array-like of shape (H, n_samples)
        ReHU cutpoint matrix.

    S : array-like of shape (H, n_samples)
        ReHU coefficient vector.

    T : array-like of shape (H, n_samples)
        ReHU intercept vector.

    C : float, default=1.0
        Regularization parameter. The strength of the regularization is
        inversely proportional to C. Must be strictly positive.

    sample_weight : array-like of shape (n_samples,), default=None
        Individual sample weights. If None, samples are equally weighted.

    Returns
    -------
    U_weight : array-like of shape (L, n_samples)
        Weighted ReLU coefficient matrix.

    V_weight : array-like of shape (L, n_samples)
        Weighted ReLU intercept vector.

    Tau_weight : array-like of shape (H, n_samples)
        Weighted ReHU cutpoint matrix.

    S_weight : array-like of shape (H, n_samples)
        Weighted ReHU coefficient vector.

    T_weight : array-like of shape (H, n_samples)
        Weighted ReHU intercept vector.

    Notes
    -----
    This function casts the sample weight onto the ReHLine parameters by
    multiplying the ReLU parameters by C * sample_weight and the ReHU
    parameters by sqrt(C * sample_weight). If sample_weight is None, the
    effective weight reduces to the regularization parameter C.
    """
    if sample_weight is None:
        sample_weight = C
    else:
        sample_weight = C * sample_weight

    if U.shape[0] > 0:
        # ReLU pieces are piecewise linear, so the weight enters linearly.
        U_weight = U * sample_weight
        V_weight = V * sample_weight
    else:
        U_weight = U
        V_weight = V

    if S.shape[0] > 0:
        # ReHU pieces are quadratic near the origin, so scaling Tau, S, T by
        # sqrt(weight) multiplies the ReHU loss by the weight itself.
        sqrt_sample_weight = np.sqrt(sample_weight)
        Tau_weight = Tau * sqrt_sample_weight
        S_weight = S * sqrt_sample_weight
        T_weight = T * sqrt_sample_weight
    else:
        Tau_weight = Tau
        S_weight = S
        T_weight = T

    return U_weight, V_weight, Tau_weight, S_weight, T_weight


# def append_l1(self, X, l1_pen=1.0):
# r"""
# This function appends the l1 penalty to the ReHLine problem. The formulation becomes:
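Similarly, a minimal sketch of the weight cast, again not part of the commit and using the same assumed import path and toy shapes as above:

import numpy as np
from rehline._base import _cast_sample_weight  # assumed import path

# Toy shapes: L = 1 ReLU piece, H = 1 ReHU piece, n_samples = 3.
U = np.array([[1.0, -1.0, 2.0]])
V = np.array([[0.5, 0.5, 0.5]])
Tau = np.array([[1.0, 1.0, 1.0]])
S = np.array([[1.0, 1.0, 1.0]])
T = np.array([[0.0, 0.0, 0.0]])

w = np.array([1.0, 2.0, 0.5])  # one weight per sample
U_w, V_w, Tau_w, S_w, T_w = _cast_sample_weight(U, V, Tau, S, T, C=2.0, sample_weight=w)
# ReLU parameters scale linearly:    U_w == U * (2.0 * w)
# ReHU parameters scale by the root: S_w == S * np.sqrt(2.0 * w)

The square-root scaling is what makes a weight w on the quadratic ReHU pieces equivalent to multiplying that piece of the loss by w itself.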

0 commit comments
