Skip to content

Commit 8066b7d

Browse files
authored
Update _base.py
Made adjustments on _make_loss_rehline_param( ) function. Now this function supports giving rehline parameters for mean absolute error loss function.
1 parent 5f24620 commit 8066b7d

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

rehline/_base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,14 @@ def _make_loss_rehline_param(loss, X, y):
358358
V[0] = -(y + loss['epsilon'])
359359
V[1] = (y - loss['epsilon'])
360360

361+
362+
elif (loss['name'] == 'MAE') \
363+
or (loss['name'] == 'mae') \
364+
or (loss['name'] == 'mean absolute error'):
365+
U = np.array([[1] * n, [-1] * n])
366+
V = np.array([-y , y])
367+
368+
361369
else:
362370
raise Exception("Sorry, ReHLine currently does not support this loss function, \
363371
but you can manually set ReHLine params to solve the problem via `ReHLine` class.")
@@ -525,4 +533,4 @@ def _make_penalty_rehline_param(self, penalty=None, X=None):
525533
# self.U = U_new
526534
# self.V = V_new
527535
# self.auto_shape()
528-
# return X_fake
536+
# return X_fake

0 commit comments

Comments
 (0)