
Commit f036743

add fit_intercept support for LBFGS
1 parent ba5d9d9 commit f036743
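In short, the LBFGS solver gains a fit_intercept option, exercised by the new issue320.py benchmark below. A minimal usage sketch follows; the toy data and hyperparameter values are illustrative assumptions, while the estimator, datafit, penalty, and solver names all come from the changed files.

import numpy as np

from skglm import GeneralizedLinearEstimator
from skglm.datafits import Poisson
from skglm.penalties import L2
from skglm.solvers import LBFGS
from skglm.utils.data import make_correlated_data

# Small toy problem: Poisson counts from a linear predictor with a true
# intercept of -1.0 (values chosen purely for illustration).
X, _, w_true = make_correlated_data(n_samples=500, n_features=20, random_state=0)
y = np.random.default_rng(0).poisson(np.exp(np.clip(X @ w_true - 1.0, -15, 15)))

# fit_intercept=True activates the intercept handling added in this commit.
model = GeneralizedLinearEstimator(
    datafit=Poisson(),
    penalty=L2(alpha=0.01),
    solver=LBFGS(fit_intercept=True),
)
model.fit(X, y)
print("intercept:", model.intercept_, "first coefs:", model.coef_[:3])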

File tree

2 files changed: +225 -15 lines


issue320.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
import numpy as np
import time

from skglm import GeneralizedLinearEstimator
from skglm.datafits import Poisson
from skglm.penalties import L2
from skglm.solvers import LBFGS
from skglm.utils.data import make_correlated_data
from sklearn.linear_model import PoissonRegressor
from sklearn.metrics import mean_poisson_deviance, mean_absolute_error


def generate_correlated_poisson_data(
        n_samples=20000,
        n_features=50,
        rho=0.5,
        density=0.5,
        seed=42
):
    print("\n1. Generating synthetic correlated data for Poisson GLM...")
    print(
        f" (n_samples={n_samples}, n_features={n_features}, "
        f"rho={rho}, density={density})"
    )

    # Use make_correlated_data to get X and w_true.
    X, _, w_true = make_correlated_data(
        n_samples=n_samples,
        n_features=n_features,
        rho=rho,
        snr=10,
        density=density,
        random_state=seed
    )

    # Define a true intercept
    intercept_true = -1.0

    # Calculate the linear predictor
    eta = intercept_true + X @ w_true

    # Apply the inverse link function (clip to keep exp() well-behaved)
    eta = np.clip(eta, -15, 15)
    mu = np.exp(eta)

    # Generate the Poisson-distributed response variable
    rng = np.random.default_rng(seed)
    y = rng.poisson(mu)

    return X, y, w_true, intercept_true


def run_benchmark():
    """Main function to run the GLM benchmark."""

    # 1. Generate data
    # Parameters for data generation
    N_SAMPLES = 100000
    N_FEATURES = 1000
    RHO = 0.6
    DENSITY = 0.5  # Proportion of nonzero entries in w_true

    X, y_true, w_true, intercept_true = generate_correlated_poisson_data(
        n_samples=N_SAMPLES,
        n_features=N_FEATURES,
        rho=RHO,
        density=DENSITY,
        seed=42
    )

    # 2. Shared model parameters
    print("\n2. Setting up models...")
    alpha = 0.01   # L2 regularization strength
    tol = 1e-4     # Same tolerance as sklearn's PoissonRegressor
    iter_n = 1000  # Increase max_iter to allow convergence

    # 3a. Fit the GLM with skglm
    print("\n3a. Fitting the GLM with skglm...")
    estimator_skglm = GeneralizedLinearEstimator(
        datafit=Poisson(),
        # Using L2 penalty (Ridge) for LBFGS compatibility
        penalty=L2(alpha=alpha),
        solver=LBFGS(verbose=False, tol=tol, max_iter=iter_n, fit_intercept=True)
    )

    start_time_skglm = time.perf_counter()
    estimator_skglm.fit(X, y_true)
    end_time_skglm = time.perf_counter()
    skglm_fit_time = end_time_skglm - start_time_skglm
    print(f" skglm fit complete in {skglm_fit_time:.4f} seconds.")

    # 3b. Fit the GLM with scikit-learn
    print("\n3b. Fitting the GLM with scikit-learn...")
    # PoissonRegressor in sklearn uses an L2 penalty.
    estimator_sklearn = PoissonRegressor(
        alpha=alpha,
        fit_intercept=True,
        tol=tol,
        solver='lbfgs',
        max_iter=iter_n
    )

    # Use perf_counter here as well so both timings are comparable.
    start_time_sklearn = time.perf_counter()
    estimator_sklearn.fit(X, y_true)
    end_time_sklearn = time.perf_counter()
    sklearn_fit_time = end_time_sklearn - start_time_sklearn
    print(f" sklearn fit complete in {sklearn_fit_time:.4f} seconds.")

    # 4. Compare the results
    print("\n" + "="*80)
    print("RESULTS COMPARISON")
    print("="*80)

    # --- Coefficient Comparison ---
    print("\n--- Coefficient Comparison ---")

    # Intercept
    print(f"{'Parameter':<20} | {'Ground Truth':<15} | "
          f"{'skglm Fit':<15} | {'sklearn Fit':<15}")
    print("-" * 75)
    print(f"{'Intercept':<20} | {intercept_true:<15.4f} | "
          f"{estimator_skglm.intercept_:<15.4f} | "
          f"{estimator_sklearn.intercept_:<15.4f}")

    # MAE of Coefficients
    mae_skglm = mean_absolute_error(w_true, estimator_skglm.coef_)
    mae_sklearn = mean_absolute_error(w_true, estimator_sklearn.coef_)
    print(f"\n{'MAE vs. w_true':<20} | {'':<15} | "
          f"{mae_skglm:<15.6f} | {mae_sklearn:<15.6f}")

    # Spot-check of first 5 coefficients
    print("\nSpot-check of first 5 coefficients:")
    print(f"{'Parameter':<12} | {'Ground Truth':<15} | "
          f"{'skglm Fit':<15} | {'sklearn Fit':<15}")
    print("-" * 65)
    for i in range(min(5, N_FEATURES)):
        print(
            f"w_{i:<10} | {w_true[i]:<15.4f} | "
            f"{estimator_skglm.coef_[i]:<15.4f} | "
            f"{estimator_sklearn.coef_[i]:<15.4f}")

    # --- Timing Comparison ---
    print("\n--- Fitting Time Comparison ---")
    print(f"skglm (LBFGS): {skglm_fit_time:.4f} seconds")
    print(f"sklearn (L-BFGS): {sklearn_fit_time:.4f} seconds")
    if skglm_fit_time < sklearn_fit_time:
        speedup = (sklearn_fit_time / skglm_fit_time
                   if skglm_fit_time > 0 else float('inf'))
        print(f" >> skglm was {speedup:.2f}x faster.")
    else:
        speedup = (skglm_fit_time / sklearn_fit_time
                   if sklearn_fit_time > 0 else float('inf'))
        print(f" >> sklearn was {speedup:.2f}x faster.")

    # --- Performance Metrics Comparison ---
    def calculate_metrics(estimator, X, y_true):
        y_pred = estimator.predict(X)
        # clip to avoid log(0) in deviance calculation
        y_pred = np.clip(y_pred, 1e-9, None)
        dev_model = len(y_true) * mean_poisson_deviance(y_true, y_pred)
        return dev_model

    dev_model_skglm = calculate_metrics(estimator_skglm, X, y_true)
    dev_model_sklearn = calculate_metrics(estimator_sklearn, X, y_true)

    # Null deviance (intercept-only model: predict the mean of y everywhere)
    y_null = np.full_like(y_true, fill_value=y_true.mean(), dtype=float)
    dev_null = len(y_true) * mean_poisson_deviance(y_true, y_null)

    pseudo_r2_skglm = 1.0 - (dev_model_skglm / dev_null)
    pseudo_r2_sklearn = 1.0 - (dev_model_sklearn / dev_null)

    print("\n--- Performance Metrics ---")
    print(f"{'Metric':<30} | {'skglm':<15} | {'sklearn':<15}")
    print("-" * 65)
    print(f"{'Model Deviance':<30} | {dev_model_skglm:<15,.2f} | "
          f"{dev_model_sklearn:<15,.2f}")
    print(f"{'Null Deviance':<30} | {dev_null:<15,.2f} | {dev_null:<15,.2f}")
    print(f"{'Pseudo R² (Deviance Explained)':<30} | "
          f"{pseudo_r2_skglm:<15.4f} | {pseudo_r2_sklearn:<15.4f}")


if __name__ == "__main__":
    run_benchmark()

skglm/solvers/lbfgs.py

Lines changed: 41 additions & 15 deletions
@@ -24,16 +24,21 @@ class LBFGS(BaseSolver):
     tol : float, default 1e-4
         Tolerance for convergence.
 
+    fit_intercept : bool, default False
+        Whether or not to fit an intercept.
+
     verbose : bool, default False
         Amount of verbosity. 0/False is silent.
     """
 
     _datafit_required_attr = ("gradient",)
     _penalty_required_attr = ("gradient",)
 
-    def __init__(self, max_iter=50, tol=1e-4, verbose=False):
+    def __init__(self, max_iter=50, tol=1e-4, fit_intercept=False, verbose=False):
         self.max_iter = max_iter
         self.tol = tol
+        self.fit_intercept = fit_intercept
+        self.warm_start = False
         self.verbose = verbose
 
     def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

@@ -46,25 +51,46 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         datafit.initialize(X, y)
 
         def objective(w):
-            Xw = X @ w
-            datafit_value = datafit.value(y, w, Xw)
-            penalty_value = penalty.value(w)
+            if self.fit_intercept:
+                Xw = X @ w[:-1] + w[-1]
+                datafit_value = datafit.value(y, w[:-1], Xw)
+                penalty_value = penalty.value(w[:-1])
+            else:
+                Xw = X @ w
+                datafit_value = datafit.value(y, w, Xw)
+                penalty_value = penalty.value(w)
 
             return datafit_value + penalty_value
 
         def d_jac(w):
-            Xw = X @ w
-            datafit_grad = datafit.gradient(X, y, Xw)
-            penalty_grad = penalty.gradient(w)
-
-            return datafit_grad + penalty_grad
+            if self.fit_intercept:
+                Xw = X @ w[:-1] + w[-1]
+                datafit_grad = datafit.gradient(X, y, Xw)
+                penalty_grad = penalty.gradient(w[:-1])
+                intercept_grad = datafit.intercept_update_step(y, Xw)
+                return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
+            else:
+                Xw = X @ w
+                datafit_grad = datafit.gradient(X, y, Xw)
+                penalty_grad = penalty.gradient(w)
+
+                return datafit_grad + penalty_grad
 
         def s_jac(w):
-            Xw = X @ w
-            datafit_grad = datafit.gradient_sparse(X.data, X.indptr, X.indices, y, Xw)
-            penalty_grad = penalty.gradient(w)
-
-            return datafit_grad + penalty_grad
+            if self.fit_intercept:
+                Xw = X @ w[:-1] + w[-1]
+                datafit_grad = datafit.gradient_sparse(
+                    X.data, X.indptr, X.indices, y, Xw)
+                penalty_grad = penalty.gradient(w[:-1])
+                intercept_grad = datafit.intercept_update_step(y, Xw)
+                return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
+            else:
+                Xw = X @ w
+                datafit_grad = datafit.gradient_sparse(
+                    X.data, X.indptr, X.indices, y, Xw)
+                penalty_grad = penalty.gradient(w)
+
+                return datafit_grad + penalty_grad
 
         def callback_post_iter(w_k):
             # save p_obj

@@ -81,7 +107,7 @@ def callback_post_iter(w_k):
             )
 
         n_features = X.shape[1]
-        w = np.zeros(n_features) if w_init is None else w_init
+        w = np.zeros(n_features + self.fit_intercept) if w_init is None else w_init
         jac = s_jac if issparse(X) else d_jac
         p_objs_out = []
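For reference, the coefficients-plus-intercept layout used above can be reproduced outside skglm: the sketch below solves an L2-penalized Poisson regression with SciPy's L-BFGS-B, storing the intercept as the last entry of w, leaving it out of the penalty, and stacking its gradient onto the coefficient gradient, mirroring objective, d_jac, and the np.zeros(n_features + self.fit_intercept) initialization. The hand-written loss and ridge term are illustrative assumptions, not the exact skglm datafit and penalty formulas.

import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 5))
y = rng.poisson(np.exp(X @ np.array([0.3, -0.2, 0.1, 0.0, 0.4]) - 1.0))
alpha = 0.01

def objective(w):
    # w[:-1] holds the coefficients, w[-1] the intercept (same layout as in _solve).
    Xw = X @ w[:-1] + w[-1]
    datafit = np.mean(np.exp(Xw) - y * Xw)       # Poisson negative log-likelihood, up to constants
    penalty = 0.5 * alpha * np.sum(w[:-1] ** 2)  # the intercept is not penalized
    return datafit + penalty

def jac(w):
    Xw = X @ w[:-1] + w[-1]
    residual = (np.exp(Xw) - y) / len(y)
    grad_coef = X.T @ residual + alpha * w[:-1]
    grad_intercept = residual.sum()              # gradient w.r.t. the intercept coordinate
    return np.concatenate([grad_coef, [grad_intercept]])

# One extra coordinate for the intercept, mirroring n_features + self.fit_intercept.
w0 = np.zeros(X.shape[1] + 1)
res = minimize(objective, w0, jac=jac, method="L-BFGS-B")
print("coef:", np.round(res.x[:-1], 3), "intercept:", round(res.x[-1], 3))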

0 commit comments
