diff --git a/examples/plot_group_logistic_regression.py b/examples/plot_group_logistic_regression.py index 5c1e8f71..75854324 100644 --- a/examples/plot_group_logistic_regression.py +++ b/examples/plot_group_logistic_regression.py @@ -16,6 +16,8 @@ from skglm.solvers import GroupProxNewton from skglm.utils.data import make_correlated_data, grp_converter +import matplotlib.pyplot as plt + n_features = 30 X, y, _ = make_correlated_data( n_samples=10, n_features=30, random_state=0) @@ -42,3 +44,23 @@ # %% # Fit check that groups are either all 0 or all non zero print(clf.coef_.reshape(-1, grp_size)) + +# %% +# Visualise group-level sparsity + +coef_by_group = clf.coef_.reshape(-1, grp_size) +group_norms = np.linalg.norm(coef_by_group, axis=1) + +plt.figure(figsize=(8, 4)) +plt.bar(np.arange(n_groups), group_norms) +plt.xlabel("Group index") +plt.ylabel("L2 norm of coefficients") +plt.title("Group Sparsity Pattern") +plt.tight_layout() +plt.show() + +# %% +# This plot shows the L2 norm of the coefficients for each group. +# Groups with a zero norm have been set inactive by the model, +# illustrating how Group Logistic Regression enforces sparsity at the group level. +# (Note: This example uses a tiny synthetic dataset, so the pattern has limited interpretability.)