Skip to content

Commit 8aea611

Browse files
authored
DOC add group logistic regression example (#246)
1 parent 3762627 commit 8aea611

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
===================================
3+
Group Logistic regression in Python
4+
===================================
5+
Scikit-learn is missing a Group Logistic regression estimator. We show how to implement
6+
one with ``skglm``.
7+
"""
8+
9+
# Author: Mathurin Massias
10+
11+
import numpy as np
12+
13+
from skglm import GeneralizedLinearEstimator
14+
from skglm.datafits import LogisticGroup
15+
from skglm.penalties import WeightedGroupL2
16+
from skglm.solvers import GroupProxNewton
17+
from skglm.utils.data import make_correlated_data, grp_converter
18+
19+
n_features = 30
# Synthetic data with fewer samples (10) than features. ``n_features`` is
# passed by name rather than repeating the literal, so the data always matches
# the group structure derived from ``n_features`` further down.
X, y, _ = make_correlated_data(
    n_samples=10, n_features=n_features, random_state=0)
# Keep only the sign of the generated target to obtain class labels.
y = np.sign(y)
23+
24+
25+
# %%
# Classifier creation: combination of penalty, datafit and solver.
#
alpha = 0.01  # passed as first argument to the penalty below
grp_size = 3  # each group is made of 3 consecutive features
n_groups = n_features // grp_size
grp_indices, grp_ptr = grp_converter(grp_size, n_features=n_features)

solver = GroupProxNewton(verbose=2)
datafit = LogisticGroup(grp_ptr, grp_indices)
weights = np.full(n_groups, 1.0)  # same weight for every group
penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices)
36+
37+
# %%
# Train the model
clf = GeneralizedLinearEstimator(
    datafit=datafit, penalty=penalty, solver=solver)
clf.fit(X, y)
41+
42+
# %%
# Check that each group is either entirely zero or entirely nonzero:
# every printed row holds the coefficients of one group.
coef_by_group = clf.coef_.reshape(-1, grp_size)
print(coef_by_group)

0 commit comments

Comments
 (0)