Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit ecd9ad0

Browse files
authored
Fix forgotten intercept in SGDRegressor (#147)
* fix forgotten intercept in SGDRegressor * update tests
1 parent aa7dd4c commit ecd9ad0

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

lightning/impl/sgd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(self, loss="hinge", penalty="l2",
150150
self.n_calls = n_calls
151151
self.verbose = verbose
152152
self.coef_ = None
153+
self.intercept_ = None
153154

154155
def _get_loss(self):
155156
if self.multiclass:
@@ -395,6 +396,7 @@ def predict(self, X):
395396
try:
396397
assert_all_finite(self.coef_)
397398
pred = safe_sparse_dot(X, self.coef_.T)
399+
pred += self.intercept_
398400
except ValueError:
399401
n_samples = X.shape[0]
400402
n_vectors = self.coef_.shape[0]

lightning/impl/tests/test_sgd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_regression_squared_loss():
106106

107107
reg.fit(X, y)
108108
pred = reg.predict(X)
109-
assert_almost_equal(np.mean((pred - y) ** 2), 4.913, 3)
109+
assert_almost_equal(np.mean((pred - y) ** 2), 4.749, 3)
110110

111111

112112
def test_regression_squared_loss_nn_l1():
@@ -119,7 +119,7 @@ def test_regression_squared_loss_nn_l1():
119119

120120
reg.fit(X, y)
121121
pred = reg.predict(X)
122-
assert_almost_equal(np.mean((pred - y) ** 2), 0.033, 3)
122+
assert_almost_equal(np.mean((pred - y) ** 2), 0.016, 3)
123123
assert_false((reg.coef_ < 0).any())
124124

125125

@@ -132,7 +132,7 @@ def test_regression_squared_loss_nn_l2():
132132

133133
reg.fit(X, y)
134134
pred = reg.predict(X)
135-
assert_almost_equal(np.mean((pred - y) ** 2), 0.033, 3)
135+
assert_almost_equal(np.mean((pred - y) ** 2), 0.016, 3)
136136
assert_almost_equal(reg.coef_.sum(), 2.131, 3)
137137
assert_false((reg.coef_ < 0).any())
138138

@@ -147,5 +147,5 @@ def test_regression_squared_loss_multiple_output():
147147
Y[:, 1] = y
148148
reg.fit(X, Y)
149149
pred = reg.predict(X)
150-
assert_almost_equal(np.mean((pred - Y) ** 2), 4.541, 3)
150+
assert_almost_equal(np.mean((pred - Y) ** 2), 4.397, 3)
151151

0 commit comments

Comments
 (0)