Skip to content

Commit c85391d

Browse files
committed
fix _loss.py
1 parent 0b740d5 commit c85391d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

rehline/_loss.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __call__(self, x):
6868
----------
6969
x: {array-like} of shape (n_samples, )
7070
Training vector, where `n_samples` is the number of samples
71+
72+
For ERM question, the input of this function is np.dot(X, self.coef_) rather than a single X.
7173
"""
7274
if (self.L > 0) and (self.H > 0):
7375
assert self.relu_coef.shape[1] == self.rehu_coef.shape[1], "n_samples for `relu_coef` and `rehu_coef` should be the same shape!"
@@ -80,9 +82,9 @@ def __call__(self, x):
8082
ans = 0
8183
if len(self.relu_coef) > 0:
8284
relu_input = (self.relu_coef.T * x[:,np.newaxis]).T + self.relu_intercept
83-
ans += np.sum(relu(relu_input), 0).sum()
85+
ans += np.sum(_relu(relu_input), 0).sum()
8486
if len(self.rehu_coef) > 0:
8587
rehu_input = (self.rehu_coef.T * x[:,np.newaxis]).T + self.rehu_intercept
86-
ans += np.sum(rehu(rehu_input, cut=self.rehu_cut), 0).sum()
88+
ans += np.sum(_rehu(rehu_input, cut=self.rehu_cut), 0).sum()
8789

8890
return ans

0 commit comments

Comments
 (0)