
Commit 21f1459

first try at simple quantile huber
1 parent 530c55a commit 21f1459

File tree

4 files changed: +278 −0 lines changed

examples/plot_smooth_quantile.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
"""
===========================================
Smooth Quantile Regression Example
===========================================

Compare ``SmoothQuantileRegressor`` (progressive smoothing) against
scikit-learn's ``QuantileRegressor`` on synthetic data.
"""

import time

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import QuantileRegressor

from skglm.experimental.smooth_quantile_regressor import SmoothQuantileRegressor
from skglm.experimental.quantile_huber import QuantileHuber


X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)
X = StandardScaler().fit_transform(X)
tau = 0.75

t0 = time.time()
reg_skglm = SmoothQuantileRegressor(quantile=tau).fit(X, y)
t1 = time.time()
reg_sklearn = QuantileRegressor(quantile=tau, alpha=0.1, solver='highs').fit(X, y)
t2 = time.time()

y_pred_skglm, y_pred_sklearn = reg_skglm.predict(X), reg_sklearn.predict(X)
coverage_skglm = np.mean(y <= y_pred_skglm)
coverage_sklearn = np.mean(y <= y_pred_sklearn)

print(f"\nTiming: skglm={t1 - t0:.3f}s, sklearn={t2 - t1:.3f}s, "
      f"speedup={(t2 - t1) / (t1 - t0):.1f}x")
print(f"Coverage (target {tau}): skglm={coverage_skglm:.3f}, "
      f"sklearn={coverage_sklearn:.3f}")
print(f"Non-zero coefs: skglm={np.sum(reg_skglm.coef_ != 0)}, "
      f"sklearn={np.sum(reg_sklearn.coef_ != 0)}")


def pinball(y_true, y_pred):
    """Pinball (check) loss at level tau."""
    diff = y_true - y_pred
    return np.mean(np.where(diff >= 0, tau * diff, (1 - tau) * -diff))


print(f"Pinball loss: skglm={pinball(y, y_pred_skglm):.4f}, "
      f"sklearn={pinball(y, y_pred_sklearn):.4f}")

# Visualizations
plt.figure(figsize=(12, 5))

plt.subplot(121)
residuals = np.linspace(-2, 2, 1000)
for delta in [1.0, 0.5, 0.1]:
    loss = QuantileHuber(quantile=tau, delta=delta)
    losses = [loss.value(np.array([r]), np.array([[1]]), np.array([0]))
              for r in residuals]
    plt.plot(residuals, losses, label=f'δ={delta}')
plt.plot(residuals, [tau * max(r, 0) + (1 - tau) * max(-r, 0)
                     for r in residuals], 'k--', label='Pinball')
plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)
plt.xlabel('Residual (y - y_pred)')
plt.ylabel('Loss')
plt.title('Quantile Huber Loss (τ=0.75)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(122)
plt.hist(y - y_pred_skglm, bins=50, alpha=0.5, label='skglm')
plt.hist(y - y_pred_sklearn, bins=50, alpha=0.5, label='sklearn')
plt.axvline(0, color='k', linestyle='--')
plt.xlabel('Residual (y - y_pred)')
plt.ylabel('Count')
plt.title('Residuals Histogram')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

skglm/experimental/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -2,11 +2,15 @@
 from .sqrt_lasso import SqrtLasso, SqrtQuadratic
 from .pdcd_ws import PDCD_WS
 from .quantile_regression import Pinball
+from .quantile_huber import QuantileHuber
+from .smooth_quantile_regressor import SmoothQuantileRegressor

 __all__ = [
     IterativeReweightedL1,
     PDCD_WS,
     Pinball,
     SqrtQuadratic,
     SqrtLasso,
+    QuantileHuber,
+    SmoothQuantileRegressor,
 ]
skglm/experimental/quantile_huber.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import numpy as np
from numba import float64

from skglm.datafits.single_task import Huber
from skglm.utils.sparse_ops import spectral_norm


class QuantileHuber(Huber):
    r"""Quantile Huber loss for quantile regression.

    Implements the smoothed pinball loss with a quadratic region around zero:

    .. math::

        \rho_\tau^\delta(r) =
        \begin{cases}
            \tau \left( r - \dfrac{\delta}{2} \right), & \text{if } r \ge \delta, \\
            \dfrac{\tau r^{2}}{2 \delta}, & \text{if } 0 \le r < \delta, \\
            \dfrac{(1 - \tau) r^{2}}{2 \delta}, & \text{if } -\delta < r < 0, \\
            (1 - \tau) \left( -r - \dfrac{\delta}{2} \right), & \text{if } r \le -\delta.
        \end{cases}

    Parameters
    ----------
    quantile : float, default=0.5
        Desired quantile level, strictly between 0 and 1.

    delta : float, default=1.0
        Width of the quadratic region around the origin.

    References
    ----------
    Chen, C. (2007). A Finite Smoothing Algorithm for Quantile Regression.
    Journal of Computational and Graphical Statistics, 16(1), 136–164.
    http://www.jstor.org/stable/27594233
    """

    def __init__(self, quantile=0.5, delta=1.0):
        if not 0 < quantile < 1:
            raise ValueError("quantile must be between 0 and 1")
        self.delta = float(delta)
        self.quantile = float(quantile)

    def get_spec(self):
        return (('delta', float64), ('quantile', float64))

    def params_to_dict(self):
        return dict(delta=self.delta, quantile=self.quantile)

    def _loss_and_grad_scalar(self, residual):
        """Compute loss and gradient for a single residual."""
        tau = self.quantile
        delta = self.delta
        abs_r = abs(residual)

        # Quadratic core: |r| <= delta
        if abs_r <= delta:
            if residual >= 0:
                # 0 <= r <= delta
                loss = tau * residual ** 2 / (2 * delta)
                grad = tau * residual / delta
            else:
                # -delta <= r < 0
                loss = (1 - tau) * residual ** 2 / (2 * delta)
                grad = (1 - tau) * residual / delta
            return loss, grad

        # Linear tails: |r| > delta
        if residual > delta:
            loss = tau * (residual - delta / 2)
            grad = tau
        else:
            loss = (1 - tau) * (-residual - delta / 2)
            grad = tau - 1
        return loss, grad

    def value(self, y, w, Xw):
        """Compute the quantile Huber loss value."""
        residuals = y - Xw
        loss = np.zeros_like(residuals)
        for i, r in enumerate(residuals):
            loss[i], _ = self._loss_and_grad_scalar(r)
        return np.mean(loss)

    def raw_grad(self, y, Xw):
        """Compute gradient of the datafit w.r.t. Xw."""
        residuals = y - Xw
        grad = np.zeros_like(residuals)
        for i, r in enumerate(residuals):
            _, grad[i] = self._loss_and_grad_scalar(r)
        # the residual is y - Xw, hence the sign flip w.r.t. r
        return -grad

    def get_lipschitz(self, X, y):
        """Compute coordinate-wise Lipschitz constants."""
        weight = max(self.quantile, 1 - self.quantile)
        return weight * (X ** 2).sum(axis=0) / (len(y) * self.delta)

    def get_global_lipschitz(self, X, y):
        """Compute the global Lipschitz constant."""
        weight = max(self.quantile, 1 - self.quantile)
        return weight * np.linalg.norm(X, 2) ** 2 / (len(y) * self.delta)

    def get_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
        """Compute coordinate-wise Lipschitz constants for sparse X."""
        n_samples = len(y)
        weight = max(self.quantile, 1 - self.quantile)
        n_features = len(X_indptr) - 1
        lipschitz = np.zeros(n_features, dtype=X_data.dtype)
        for j in range(n_features):
            nrm2 = 0.
            for idx in range(X_indptr[j], X_indptr[j + 1]):
                nrm2 += X_data[idx] ** 2
            lipschitz[j] = weight * nrm2 / (n_samples * self.delta)
        return lipschitz

    def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
        """Compute the global Lipschitz constant for sparse X."""
        n_samples = len(y)
        weight = max(self.quantile, 1 - self.quantile)
        return weight * spectral_norm(
            X_data, X_indptr, X_indices, n_samples) ** 2 / (n_samples * self.delta)

    def intercept_update_step(self, y, Xw):
        return -np.mean(self.raw_grad(y, Xw))
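
A quick sanity check, not part of the commit and assuming the module layout above: the smoothed loss should be continuous at the knots r = ±δ, and it should approach the pinball loss as δ → 0.

import numpy as np
from skglm.experimental.quantile_huber import QuantileHuber

tau, delta = 0.75, 0.5
qh = QuantileHuber(quantile=tau, delta=delta)

# continuity at the knots: quadratic core and linear tail must agree at r = +/- delta
for knot in (delta, -delta):
    inner, _ = qh._loss_and_grad_scalar(knot * (1 - 1e-9))  # quadratic side
    outer, _ = qh._loss_and_grad_scalar(knot * (1 + 1e-9))  # linear side
    assert abs(inner - outer) < 1e-8

# delta -> 0: the loss tends to the pinball loss tau*max(r, 0) + (1 - tau)*max(-r, 0)
r = 1.3
loss_tiny_delta, _ = QuantileHuber(quantile=tau, delta=1e-8)._loss_and_grad_scalar(r)
assert abs(loss_tiny_delta - tau * r) < 1e-6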
skglm/experimental/smooth_quantile_regressor.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array

from ..solvers import FISTA
from ..penalties import L1
from ..estimators import GeneralizedLinearEstimator
from .quantile_huber import QuantileHuber


class SmoothQuantileRegressor(BaseEstimator, RegressorMixin):
    """Quantile regression with progressive smoothing of a Huberized loss."""

    def __init__(self, quantile=0.75, alpha=1e-8, max_iter=1000, tol=1e-6,
                 delta_init=1.0, delta_final=1e-4, n_deltas=10, fit_intercept=True):
        self.quantile = quantile
        self.alpha = alpha
        self.max_iter = max_iter
        self.tol = tol
        self.delta_init = delta_init
        self.delta_final = delta_final
        self.n_deltas = n_deltas
        self.fit_intercept = fit_intercept
        self.intercept_ = 0.0

    def fit(self, X, y):
        """Fit using FISTA with a geometrically decreasing smoothing parameter delta.

        For each delta level:

        - update the coefficients with FISTA,
        - update the intercept with a single gradient step.
        """
        X, y = check_X_y(X, y)
        w = np.zeros(X.shape[1])
        intercept = np.quantile(y, self.quantile) if self.fit_intercept else 0.0

        for delta in np.geomspace(self.delta_init, self.delta_final, self.n_deltas):
            datafit = QuantileHuber(quantile=self.quantile, delta=delta)
            est = GeneralizedLinearEstimator(
                datafit=datafit,
                penalty=L1(alpha=self.alpha),
                solver=FISTA(max_iter=self.max_iter, tol=self.tol),
            )
            # carry over the solution from the previous (larger) delta
            est.coef_ = w
            est.fit(X, y)
            w = est.coef_

            lipschitz = datafit.get_global_lipschitz(X, y)
            if self.fit_intercept:
                pred = X @ w + intercept
                grad = np.mean(datafit.raw_grad(y, pred))
                intercept -= grad / lipschitz

            # Debug prints
            residuals = y - X @ w - intercept
            # evaluate the smoothed loss on the residuals (y=residuals, Xw=0)
            obj_value = datafit.value(residuals, None, np.zeros_like(residuals)) + \
                self.alpha * np.sum(np.abs(w))
            print(f"Delta: {delta:.6f}, Objective: {obj_value:.4f}, "
                  f"Intercept: {intercept:.4f}, "
                  f"Non-zero coefs: {np.sum(np.abs(w) > 1e-6)}, "
                  f"Lipschitz: {lipschitz:.4f}")
            print(f"Residual stats - mean: {np.mean(residuals):.4f}, "
                  f"std: {np.std(residuals):.4f}, "
                  f"min: {np.min(residuals):.4f}, "
                  f"max: {np.max(residuals):.4f}")

        coverage = np.mean(y <= X @ w + intercept)
        print(f"Coverage: {coverage:.4f} (target: {self.quantile:.4f})")

        self.coef_, self.intercept_ = w, intercept
        return self

    def predict(self, X):
        """Predict using the fitted model."""
        X = check_array(X)
        return X @ self.coef_ + self.intercept_
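
For reference, a minimal usage sketch on toy data (illustrative only; the dataset and settings are made up, see examples/plot_smooth_quantile.py for the full comparison):

import numpy as np
from skglm.experimental.smooth_quantile_regressor import SmoothQuantileRegressor

# toy data: 200 samples, 5 features (arbitrary choices for illustration)
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
y = X @ rng.standard_normal(5) + rng.standard_normal(200)

# fit the 0.9 conditional quantile with 5 smoothing levels
reg = SmoothQuantileRegressor(quantile=0.9, n_deltas=5).fit(X, y)
print("coverage:", np.mean(y <= reg.predict(X)))  # should land near 0.9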
