Skip to content

Commit c92c119

Browse files
mathurinm and QB3 authored
API remove is_classif attribute of GeneralizedLinearModel (#66)
Co-authored-by: QB3 <[email protected]>
1 parent a9d42c6 commit c92c119

File tree

5 files changed

+29
-32
lines changed

5 files changed

+29
-32
lines changed

doc/add.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
With skglm, you can solve any custom Generalized Linear Model with arbitrary smooth datafit and arbitrary proximable penalty, by defining two classes: a ``Penalty`` and a ``Datafit``.
66

7-
They can then be passed to a :class:`~skglm.GeneralizedLinearEstimator`, using ``is_classif`` to specify if the task is classification or regression.
7+
They can then be passed to a :class:`~skglm.GeneralizedLinearEstimator`.
88

99

1010
.. code-block:: python
1111
1212
clf = GeneralizedLinearEstimator(
13-
MyDatafit(), MyPenalty(), is_classif=True
13+
MyDatafit(),
14+
MyPenalty(),
1415
)
1516
1617

doc/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
#
9292
# This is also used if you do content translation via gettext catalogs.
9393
# Usually you set "language" from the command line for these cases.
94-
language = None
94+
language = 'en'
9595

9696
# There are two options for replacing |today|: either, you set today to some
9797
# non-false value, then it is used:

examples/plot_logreg_various_penalties.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,17 @@
3535
alpha = 0.005
3636
gamma = 3.0
3737
l1_ratio = 0.3
38-
clf_enet = GeneralizedLinearEstimator(Logistic(), L1_plus_L2(alpha, l1_ratio),
39-
is_classif=True, verbose=0)
38+
clf_enet = GeneralizedLinearEstimator(
39+
Logistic(),
40+
L1_plus_L2(alpha, l1_ratio),
41+
)
4042
y_pred_enet = clf_enet.fit(X_train, y_train).predict(X_test)
4143
f1_score_enet = f1_score(y_test, y_pred_enet)
4244

43-
clf_mcp = GeneralizedLinearEstimator(Logistic(), MCPenalty(alpha, gamma),
44-
is_classif=True, verbose=0)
45+
clf_mcp = GeneralizedLinearEstimator(
46+
Logistic(),
47+
MCPenalty(alpha, gamma),
48+
)
4549
y_pred_mcp = clf_mcp.fit(X_train, y_train).predict(X_test)
4650
f1_score_mcp = f1_score(y_test, y_pred_mcp)
4751

skglm/estimators.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626

2727

2828
def _glm_fit(X, y, model, datafit, penalty):
29-
is_classif = False
30-
if isinstance(datafit, Logistic) or isinstance(datafit, QuadraticSVC):
31-
is_classif = True
29+
is_classif = isinstance(datafit, (Logistic, QuadraticSVC))
3230

3331
if is_classif:
3432
check_classification_targets(y)
@@ -185,10 +183,6 @@ class GeneralizedLinearEstimator(LinearModel):
185183
Penalty. If None, `penalty` is initialized as a `L1` penalty.
186184
`penalty` is replaced by a JIT-compiled instance when calling fit.
187185
188-
is_classif : bool, optional
189-
Whether the task is classification or regression. Used for input target
190-
validation.
191-
192186
max_iter : int, optional
193187
The maximum number of iterations (subproblem definitions).
194188
@@ -229,11 +223,10 @@ class GeneralizedLinearEstimator(LinearModel):
229223
Number of subproblems solved to reach the specified tolerance.
230224
"""
231225

232-
def __init__(self, datafit=None, penalty=None, is_classif=False, max_iter=100,
226+
def __init__(self, datafit=None, penalty=None, max_iter=100,
233227
max_epochs=50_000, p0=10, tol=1e-4, fit_intercept=True,
234228
warm_start=False, ws_strategy="subdiff", verbose=0):
235229
super(GeneralizedLinearEstimator, self).__init__()
236-
self.is_classif = is_classif
237230
self.tol = tol
238231
self.max_iter = max_iter
239232
self.fit_intercept = fit_intercept
@@ -254,9 +247,9 @@ def __repr__(self):
254247
String representation.
255248
"""
256249
return (
257-
'GeneralizedLinearEstimator(datafit=%s, penalty=%s, alpha=%s, classif=%s)'
250+
'GeneralizedLinearEstimator(datafit=%s, penalty=%s, alpha=%s)'
258251
% (self.datafit.__class__.__name__, self.penalty.__class__.__name__,
259-
self.penalty.alpha, self.is_classif))
252+
self.penalty.alpha))
260253

261254
def fit(self, X, y):
262255
"""Fit estimator.
@@ -300,7 +293,7 @@ def predict(self, X):
300293
y_pred : array, shape (n_samples)
301294
Contain the target values for each sample.
302295
"""
303-
if self.is_classif:
296+
if isinstance(self.datafit, (Logistic, QuadraticSVC)):
304297
scores = self._decision_function(X).ravel()
305298
if len(scores.shape) == 1:
306299
indices = (scores > 0).astype(int)

skglm/tests/test_estimators.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -154,26 +154,25 @@ def test_mtl_path():
154154

155155

156156
# Test if GeneralizedLinearEstimator returns the correct coefficients
157-
@pytest.mark.parametrize("Datafit, Penalty, is_classif, Estimator, pen_args", [
158-
(Quadratic, L1, False, Lasso, [alpha]),
159-
(Quadratic, WeightedL1, False, WeightedLasso,
157+
@pytest.mark.parametrize("Datafit, Penalty, Estimator, pen_args", [
158+
(Quadratic, L1, Lasso, [alpha]),
159+
(Quadratic, WeightedL1, WeightedLasso,
160160
[alpha, np.random.choice(3, n_features)]),
161-
(Quadratic, L1_plus_L2, False, ElasticNet, [alpha, 0.3]),
162-
(Quadratic, MCPenalty, False, MCPRegression, [alpha, 3]),
163-
(QuadraticSVC, IndicatorBox, True, LinearSVC, [alpha]),
164-
(Logistic, L1, True, SparseLogisticRegression, [alpha]),
161+
(Quadratic, L1_plus_L2, ElasticNet, [alpha, 0.3]),
162+
(Quadratic, MCPenalty, MCPRegression, [alpha, 3]),
163+
(QuadraticSVC, IndicatorBox, LinearSVC, [alpha]),
164+
(Logistic, L1, SparseLogisticRegression, [alpha]),
165165
])
166166
@pytest.mark.parametrize('fit_intercept', [True, False])
167-
def test_generic_estimator(
168-
fit_intercept, Datafit, Penalty, is_classif, Estimator, pen_args):
167+
def test_generic_estimator(fit_intercept, Datafit, Penalty, Estimator, pen_args):
169168
if isinstance(Datafit(), QuadraticSVC) and fit_intercept:
170169
pytest.xfail()
171170
elif Datafit == Logistic and fit_intercept:
172171
pytest.xfail("TODO support intercept in Logistic datafit")
173172
else:
174173
target = Y if Datafit == QuadraticMultiTask else y
175174
gle = GeneralizedLinearEstimator(
176-
Datafit(), Penalty(*pen_args), is_classif, tol=1e-10,
175+
Datafit(), Penalty(*pen_args), tol=1e-10,
177176
fit_intercept=fit_intercept).fit(X, target)
178177
est = Estimator(
179178
*pen_args, tol=1e-10, fit_intercept=fit_intercept).fit(X, target)
@@ -201,7 +200,7 @@ def test_estimator_predict(Datafit, Penalty, Estimator_sk):
201200
}
202201
X_test = np.random.normal(0, 1, (n_samples, n_features))
203202
clf = GeneralizedLinearEstimator(
204-
Datafit(), Penalty(1.), is_classif, fit_intercept=False, tol=tol).fit(X, y)
203+
Datafit(), Penalty(1.), fit_intercept=False, tol=tol).fit(X, y)
205204
clf_sk = Estimator_sk(**estim_args[Estimator_sk]).fit(X, y)
206205
y_pred = clf.predict(X_test)
207206
y_pred_sk = clf_sk.predict(X_test)
@@ -221,8 +220,8 @@ def assert_deep_dict_equal(expected_attr, estimator):
221220
else:
222221
assert v == v_est
223222

224-
reg = GeneralizedLinearEstimator(Quadratic(), L1(4.), is_classif=False)
225-
clf = GeneralizedLinearEstimator(Logistic(), MCPenalty(2., 3.), is_classif=True)
223+
reg = GeneralizedLinearEstimator(Quadratic(), L1(4.))
224+
clf = GeneralizedLinearEstimator(Logistic(), MCPenalty(2., 3.))
226225

227226
# Xty and lipschitz attributes are defined for jit compiled classes
228227
# hence they are not included in the test

0 commit comments

Comments (0)