@@ -54,24 +54,50 @@ def __init__(self, *, C=1.,
5454 S = np .empty (shape = (0 ,0 )), T = np .empty (shape = (0 ,0 )),
5555 A = np .empty (shape = (0 ,0 )), b = np .empty (shape = (0 ))):
5656 self .C = C
57- self .U = U
58- self .V = V
59- self .S = S
60- self .T = T
61- self .Tau = Tau
62- self .A = A
63- self .b = b
64- self .L = U .shape [0 ]
65- self .H = S .shape [0 ]
66- self .K = A .shape [0 ]
57+ self ._U = U
58+ self ._V = V
59+ self ._S = S
60+ self ._T = T
61+ self ._Tau = Tau
62+ self ._A = A
63+ self ._b = b
64+ self .L = self ._U .shape [0 ]
65+ self .H = self ._S .shape [0 ]
66+ self .K = self ._A .shape [0 ]
67+
68+ def get_params (self , deep = True ):
69+ """Get parameters for this estimator.
70+
71+ Override the default get_params to exclude computation-only parameters.
72+
73+ Parameters
74+ ----------
75+ deep : bool, default=True
76+ If True, will return the parameters for this estimator and
77+ contained subobjects that are estimators.
78+
79+ Returns
80+ -------
81+ params : dict
82+ Parameter names mapped to their values.
83+ """
84+ out = dict ()
85+ for key in self ._get_param_names ():
86+ if key not in ['U' , 'V' , 'S' , 'T' , 'Tau' , 'A' , 'b' , 'Lambda' , 'Gamma' , 'xi' ]:
87+ value = getattr (self , key )
88+ if deep and hasattr (value , 'get_params' ) and not isinstance (value , type ):
89+ deep_items = value .get_params ().items ()
90+ out .update ((key + '__' + k , val ) for k , val in deep_items )
91+ out [key ] = value
92+ return out
6793
6894 def auto_shape (self ):
6995 """
7096 Automatically generate the shape of the parameters of the ReHLine loss function.
7197 """
72- self .L = self .U .shape [0 ]
73- self .H = self .S .shape [0 ]
74- self .K = self .A .shape [0 ]
98+ self .L = self ._U .shape [0 ]
99+ self .H = self ._S .shape [0 ]
100+ self .K = self ._A .shape [0 ]
75101
76102 def cast_sample_weight (self , sample_weight = None ):
77103 """
@@ -111,21 +137,21 @@ def cast_sample_weight(self, sample_weight=None):
111137 sample_weight = self .C * sample_weight
112138
113139 if self .L > 0 :
114- U_weight = self .U * sample_weight
115- V_weight = self .V * sample_weight
140+ U_weight = self ._U * sample_weight
141+ V_weight = self ._V * sample_weight
116142 else :
117- U_weight = self .U
118- V_weight = self .V
143+ U_weight = self ._U
144+ V_weight = self ._V
119145
120146 if self .H > 0 :
121147 sqrt_sample_weight = np .sqrt (sample_weight )
122- Tau_weight = self .Tau * sqrt_sample_weight
123- S_weight = self .S * sqrt_sample_weight
124- T_weight = self .T * sqrt_sample_weight
148+ Tau_weight = self ._Tau * sqrt_sample_weight
149+ S_weight = self ._S * sqrt_sample_weight
150+ T_weight = self ._T * sqrt_sample_weight
125151 else :
126- Tau_weight = self .Tau
127- S_weight = self .S
128- T_weight = self .T
152+ Tau_weight = self ._Tau
153+ S_weight = self ._S
154+ T_weight = self ._T
129155
130156 return U_weight , V_weight , Tau_weight , S_weight , T_weight
131157
@@ -147,9 +173,9 @@ def call_ReLHLoss(self, score):
147173 relu_input = np .zeros ((self .L , n ))
148174 rehu_input = np .zeros ((self .H , n ))
149175 if self .L > 0 :
150- relu_input = (self .U .T * score [:,np .newaxis ]).T + self .V
176+ relu_input = (self ._U .T * score [:,np .newaxis ]).T + self ._V
151177 if self .H > 0 :
152- rehu_input = (self .S .T * score [:,np .newaxis ]).T + self .T
178+ rehu_input = (self ._S .T * score [:,np .newaxis ]).T + self ._T
153179 return np .sum (_relu (relu_input ), 0 ) + np .sum (_rehu (rehu_input ), 0 )
154180
155181 @abstractmethod
0 commit comments