Skip to content

Commit 39ced0d

Browse files
add intercept method (only works for AndersonCD so far)
1 parent 938d842 commit 39ced0d

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

examples/plot_smooth_quantile.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,19 @@
88
import matplotlib.pyplot as plt
99
from sklearn.datasets import make_regression
1010

11-
# TODO: no intercept handling yet
12-
1311

1412
def pinball_loss(residuals, quantile):
1513
"""True pinball loss."""
1614
return np.mean(residuals * (quantile - (residuals < 0)))
1715

1816

19-
X, y = make_regression(n_samples=10000, n_features=1000, noise=0.1, random_state=0)
17+
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=0)
2018
tau = 0.8
21-
X_c = X - X.mean(axis=0)
22-
q_tau = np.quantile(y, tau)
23-
y_c = y - q_tau
2419

2520
start = time.time()
26-
sk = QuantileRegressor(quantile=tau, alpha=0.1, fit_intercept=False)
27-
sk.fit(X_c, y_c)
28-
sk_pred = sk.predict(X_c) + q_tau
21+
sk = QuantileRegressor(quantile=tau, alpha=0.1, fit_intercept=True)
22+
sk.fit(X, y)
23+
sk_pred = sk.predict(X)
2924
sk_time = time.time() - start
3025
sk_cov = np.mean(y <= sk_pred)
3126
sk_pinball = pinball_loss(y - sk_pred, tau)
@@ -38,11 +33,12 @@ def pinball_loss(residuals, quantile):
3833
delta_final=0.01,
3934
n_deltas=5,
4035
solver="AndersonCD",
41-
verbose=True
36+
verbose=True,
37+
fit_intercept=True
4238
)
43-
qh.fit(X_c, y_c)
39+
qh.fit(X, y)
4440
qh_time = time.time() - start
45-
qh_pred = qh.predict(X_c) + q_tau
41+
qh_pred = qh.predict(X)
4642
qh_cov = np.mean(y <= qh_pred)
4743
qh_pinball = pinball_loss(y - qh_pred, tau)
4844

skglm/experimental/quantile_huber.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,21 @@ def get_global_lipschitz(self, X, y):
115115
c = max(self.quantile, 1 - self.quantile) / self.delta
116116
return c * norm(X, ord=2) ** 2 / len(y)
117117

118+
def intercept_update_step(self, y, Xw):
119+
n_samples = len(y)
120+
update = 0.0
121+
for i in range(n_samples):
122+
residual = y[i] - Xw[i]
123+
update -= self._grad_per_sample(residual)
124+
return update / n_samples
125+
118126

119127
class SmoothQuantileRegressor(BaseEstimator, RegressorMixin):
120128
"""Quantile regression with progressive smoothing."""
121129

122130
def __init__(self, quantile=0.5, alpha=0.1, delta_init=1.0, delta_final=1e-3,
123-
n_deltas=10, max_iter=1000, tol=1e-4, verbose=False, solver="FISTA"):
131+
n_deltas=10, max_iter=1000, tol=1e-4, verbose=False,
132+
solver="AndersonCD", fit_intercept=True):
124133
self.quantile = quantile
125134
self.alpha = alpha
126135
self.delta_init = delta_init
@@ -130,6 +139,7 @@ def __init__(self, quantile=0.5, alpha=0.1, delta_init=1.0, delta_final=1e-3,
130139
self.tol = tol
131140
self.verbose = verbose
132141
self.solver = solver
142+
self.fit_intercept = fit_intercept
133143

134144
def fit(self, X, y):
135145
"""Fit using progressive smoothing: delta_init --> delta_final."""
@@ -146,11 +156,18 @@ def fit(self, X, y):
146156
# Solver selection
147157
if isinstance(self.solver, str):
148158
if self.solver == "FISTA":
159+
if self.fit_intercept:
160+
import warnings
161+
warnings.warn(
162+
"FISTA solver does not support intercept. "
163+
"Falling back to fit_intercept=False."
164+
)
165+
self.fit_intercept = False
149166
solver = FISTA(max_iter=self.max_iter, tol=self.tol)
150167
solver.warm_start = True
151168
elif self.solver == "AndersonCD":
152169
solver = AndersonCD(max_iter=self.max_iter, tol=self.tol,
153-
warm_start=True, fit_intercept=False)
170+
warm_start=True, fit_intercept=self.fit_intercept)
154171
else:
155172
raise ValueError(f"Unknown solver: {self.solver}")
156173
else:
@@ -167,6 +184,8 @@ def fit(self, X, y):
167184

168185
if self.verbose:
169186
residuals = y - X @ w
187+
if self.fit_intercept:
188+
residuals -= est.intercept_
170189
coverage = np.mean(residuals <= 0)
171190
pinball_loss = np.mean(residuals * (self.quantile - (residuals < 0)))
172191

0 commit comments

Comments
 (0)