@@ -473,6 +473,143 @@ def _make_penalty_rehline_param(self, penalty=None, X=None):
473473 raise Exception ("Sorry, `_make_penalty_rehline_param` feature is currently under development." )
474474
475475
476+ def _cast_sample_bias (U , V , Tau , S , T , sample_bias = None ):
477+ """Cast sample bias to ReHLine parameters by injecting bias into V and T.
478+
479+ This function modifies the ReHLine parameters to incorporate individual
480+ sample biases through linear transformations of the intercept parameters.
481+
482+ Parameters
483+ ----------
484+ U : array-like of shape (L, n_samples)
485+ ReLU coefficient matrix.
486+
487+ V : array-like of shape (L, n_samples)
488+ ReLU intercept vector.
489+
490+ Tau : array-like of shape (H, n_samples)
491+ ReHU cutpoint matrix.
492+
493+ S : array-like of shape (H, n_samples)
494+ ReHU coefficient vector.
495+
496+ T : array-like of shape (H, n_samples)
497+ ReHU intercept vector.
498+
499+ sample_bias : array-like of shape (n_samples, 1)
500+ Individual sample bias vector. If None, parameters are returned unchanged.
501+
502+ Returns
503+ -------
504+ U_bias : array-like of shape (L, n_samples)
505+ Biased coefficient matrix, actually doesn't change
506+
507+ V_bias : array-like of shape (L, n_samples)
508+ Biased ReLU intercept vector: V + U * sample_bias
509+
510+ Tau_bias : array-like of shape (H, n_samples)
511+ Biased ReHU cutpoint matrix, actually doesn't change
512+
513+ S_bias : array-like of shape (H, n_samples)
514+ Biased ReHU coefficient vector, actually doesn't change
515+
516+ T_bias : array-like of shape (H, n_samples)
517+ Biased ReHU intercept vector: T + S * sample_bias
518+
519+ Notes
520+ -----
521+ The transformation applies the sample bias through:
522+ - V_bias = V + U ⊙ sample_bias
523+ - T_bias = T + S ⊙ sample_bias
524+
525+ where ⊙ denotes element-wise multiplication with broadcasting.
526+ """
527+ if sample_bias is None :
528+ return U , V , Tau , S , T
529+
530+ else :
531+ sample_bias = sample_bias .reshape (1 , - 1 )
532+ U_bias = U
533+ V_bias = V + (U * sample_bias if U .shape [0 ] > 0 else 0 )
534+ Tau_bias = Tau
535+ S_bias = S
536+ T_bias = T + (S * sample_bias if S .shape [0 ] > 0 else 0 )
537+
538+ return U_bias , V_bias , Tau_bias , S_bias , T_bias
539+
540+
541+ def _cast_sample_weight (U , V , Tau , S , T , C = 1.0 , sample_weight = None ):
542+ """Apply sample weights and regularization to ReHLine parameters.
543+
544+ Parameters
545+ ----------
546+ U : array-like of shape (L, n_samples)
547+ ReLU coefficient matrix.
548+
549+ V : array-like of shape (L, n_samples)
550+ ReLU intercept vector.
551+
552+ Tau : array-like of shape (H, n_samples)
553+ ReHU cutpoint matrix.
554+
555+ S : array-like of shape (H, n_samples)
556+ ReHU coefficient vector.
557+
558+ T : array-like of shape (H, n_samples)
559+ ReHU intercept vector.
560+
561+ C : float, default=1.0
562+ Regularization parameter. The strength of the regularization is
563+ inversely proportional to C. Must be strictly positive.
564+
565+ sample_weight : array-like of shape (n_samples,), default=None
566+ Individual sample weight. If None, then samples are equally weighted.
567+
568+ Returns
569+ -------
570+ U_weight : array-like of shape (L, n_samples)
571+ Weighted ReLU coefficient matrix.
572+
573+ V_weight : array-like of shape (L, n_samples)
574+ Weighted ReLU intercept vector.
575+
576+ Tau_weight : array-like of shape (H, n_samples)
577+ Weighted ReHU cutpoint matrix.
578+
579+ S_weight : array-like of shape (H, n_samples)
580+ Weighted ReHU coefficient vector.
581+
582+ T_weight : array-like of shape (H, n_samples)
583+ Weighted ReHU intercept vector.
584+
585+ Notes
586+ -----
587+ This function casts the sample weight to the ReHLine parameters by multiplying
588+ the sample weight with the ReLU and ReHU parameters. If sample_weight is None,
589+ then the sample weight is set to the regularization parameter C.
590+ """
591+ sample_weight = C * sample_weight
592+
593+ if U .shape [0 ] > 0 :
594+ U_weight = U * sample_weight
595+ V_weight = V * sample_weight
596+ else :
597+ U_weight = U
598+ V_weight = V
599+
600+ if S .shape [0 ] > 0 :
601+ sqrt_sample_weight = np .sqrt (sample_weight )
602+ Tau_weight = Tau * sqrt_sample_weight
603+ S_weight = S * sqrt_sample_weight
604+ T_weight = T * sqrt_sample_weight
605+ else :
606+ Tau_weight = Tau
607+ S_weight = S
608+ T_weight = T
609+
610+ return U_weight , V_weight , Tau_weight , S_weight , T_weight
611+
612+
476613# def append_l1(self, X, l1_pen=1.0):
477614# r"""
478615# This function appends the l1 penalty to the ReHLine problem. The formulation becomes:
0 commit comments