Skip to content

Commit 2c96439

Browse files
committed
Overrides the get_params method to fix the bug
1 parent 4030550 commit 2c96439

File tree

7 files changed

+154
-118
lines changed

7 files changed

+154
-118
lines changed

rehline/_base.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,50 @@ def __init__(self, *, C=1.,
5454
S=np.empty(shape=(0,0)), T=np.empty(shape=(0,0)),
5555
A=np.empty(shape=(0,0)), b=np.empty(shape=(0))):
5656
self.C = C
57-
self.U = U
58-
self.V = V
59-
self.S = S
60-
self.T = T
61-
self.Tau = Tau
62-
self.A = A
63-
self.b = b
64-
self.L = U.shape[0]
65-
self.H = S.shape[0]
66-
self.K = A.shape[0]
57+
self._U = U
58+
self._V = V
59+
self._S = S
60+
self._T = T
61+
self._Tau = Tau
62+
self._A = A
63+
self._b = b
64+
self.L = self._U.shape[0]
65+
self.H = self._S.shape[0]
66+
self.K = self._A.shape[0]
67+
68+
def get_params(self, deep=True):
69+
"""Get parameters for this estimator.
70+
71+
Override the default get_params to exclude computation-only parameters.
72+
73+
Parameters
74+
----------
75+
deep : bool, default=True
76+
If True, will return the parameters for this estimator and
77+
contained subobjects that are estimators.
78+
79+
Returns
80+
-------
81+
params : dict
82+
Parameter names mapped to their values.
83+
"""
84+
out = dict()
85+
for key in self._get_param_names():
86+
if key not in ['U', 'V', 'S', 'T', 'Tau', 'A', 'b', 'Lambda', 'Gamma', 'xi']:
87+
value = getattr(self, key)
88+
if deep and hasattr(value, 'get_params') and not isinstance(value, type):
89+
deep_items = value.get_params().items()
90+
out.update((key + '__' + k, val) for k, val in deep_items)
91+
out[key] = value
92+
return out
6793

6894
def auto_shape(self):
6995
"""
7096
Automatically generate the shape of the parameters of the ReHLine loss function.
7197
"""
72-
self.L = self.U.shape[0]
73-
self.H = self.S.shape[0]
74-
self.K = self.A.shape[0]
98+
self.L = self._U.shape[0]
99+
self.H = self._S.shape[0]
100+
self.K = self._A.shape[0]
75101

76102
def cast_sample_weight(self, sample_weight=None):
77103
"""
@@ -111,21 +137,21 @@ def cast_sample_weight(self, sample_weight=None):
111137
sample_weight = self.C*sample_weight
112138

113139
if self.L > 0:
114-
U_weight = self.U * sample_weight
115-
V_weight = self.V * sample_weight
140+
U_weight = self._U * sample_weight
141+
V_weight = self._V * sample_weight
116142
else:
117-
U_weight = self.U
118-
V_weight = self.V
143+
U_weight = self._U
144+
V_weight = self._V
119145

120146
if self.H > 0:
121147
sqrt_sample_weight = np.sqrt(sample_weight)
122-
Tau_weight = self.Tau * sqrt_sample_weight
123-
S_weight = self.S * sqrt_sample_weight
124-
T_weight = self.T * sqrt_sample_weight
148+
Tau_weight = self._Tau * sqrt_sample_weight
149+
S_weight = self._S * sqrt_sample_weight
150+
T_weight = self._T * sqrt_sample_weight
125151
else:
126-
Tau_weight = self.Tau
127-
S_weight = self.S
128-
T_weight = self.T
152+
Tau_weight = self._Tau
153+
S_weight = self._S
154+
T_weight = self._T
129155

130156
return U_weight, V_weight, Tau_weight, S_weight, T_weight
131157

@@ -147,9 +173,9 @@ def call_ReLHLoss(self, score):
147173
relu_input = np.zeros((self.L, n))
148174
rehu_input = np.zeros((self.H, n))
149175
if self.L > 0:
150-
relu_input = (self.U.T * score[:,np.newaxis]).T + self.V
176+
relu_input = (self._U.T * score[:,np.newaxis]).T + self._V
151177
if self.H > 0:
152-
rehu_input = (self.S.T * score[:,np.newaxis]).T + self.T
178+
rehu_input = (self._S.T * score[:,np.newaxis]).T + self._T
153179
return np.sum(_relu(relu_input), 0) + np.sum(_rehu(rehu_input), 0)
154180

155181
@abstractmethod

0 commit comments

Comments
 (0)