1
+ """
2
+ ===================================
3
+ Group Logistic regression in python
4
+ ===================================
5
+ Scikit-learn is missing a Group Logistic regression estimator. We show how to implement
6
+ one with ``skglm``.
7
+ """
8
+
9
+ # Author: Mathurin Massias
10
+
11
+ import numpy as np
12
+
13
+ from skglm import GeneralizedLinearEstimator
14
+ from skglm .datafits import LogisticGroup
15
+ from skglm .penalties import WeightedGroupL2
16
+ from skglm .solvers import GroupProxNewton
17
+ from skglm .utils .data import make_correlated_data , grp_converter
18
+
19
+ n_features = 30
20
+ X , y , _ = make_correlated_data (
21
+ n_samples = 10 , n_features = 30 , random_state = 0 )
22
+ y = np .sign (y )
23
+
24
+
25
+ # %%
26
+ # Classifier creation: combination of penalty, datafit and solver.
27
+ #
28
+ grp_size = 3 # groups are made of groups of 3 consecutive features
29
+ n_groups = n_features // grp_size
30
+ grp_indices , grp_ptr = grp_converter (grp_size , n_features = n_features )
31
+ alpha = 0.01
32
+ weights = np .ones (n_groups )
33
+ penalty = WeightedGroupL2 (alpha , weights , grp_ptr , grp_indices )
34
+ datafit = LogisticGroup (grp_ptr , grp_indices )
35
+ solver = GroupProxNewton (verbose = 2 )
36
+
37
+ # %%
38
+ # Train the model
39
+ clf = GeneralizedLinearEstimator (datafit , penalty , solver )
40
+ clf .fit (X , y )
41
+
42
+ # %%
43
+ # Fit check that groups are either all 0 or all non zero
44
+ print (clf .coef_ .reshape (- 1 , grp_size ))
0 commit comments