156 changes: 156 additions & 0 deletions examples/plot_smooth_quantile.py
@@ -0,0 +1,156 @@
"""
Contributor Author

Some tests so far:

  • τ ≠ 0.5 (e.g. 0.8): SmoothQuantileRegressor reduces loss by >50% vs QuantileRegressor.
  • Large n (≥10 000): SmoothQuantileRegressor is 1.3×–2× faster and more accurate.
  • Median τ = 0.5 & n ≈ 1 000: scikit-learn's QuantileRegressor remains the best choice.

These are anecdotal results; your mileage may vary. Tune the smoothing sequence and inner-solver settings accordingly. There is still room for improvement.

===========================================
Fast Quantile Regression with Smoothing
===========================================
This example demonstrates how SmoothQuantileRegressor achieves faster convergence
than scikit-learn's QuantileRegressor while maintaining accuracy, particularly
for large datasets.
"""

# %%
# Data Generation
# ---------------
# First, we generate synthetic data with a known quantile structure.

import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import QuantileRegressor
from skglm.experimental.smooth_quantile_regressor import SmoothQuantileRegressor
from skglm.solvers import FISTA

# Set random seed for reproducibility
np.random.seed(42)

# Generate dataset - using a more reasonable size for quick testing
n_samples, n_features = 10000, 10 # Match test file size
X, y = make_regression(n_samples=n_samples, n_features=n_features,
noise=0.1, random_state=42)
X = StandardScaler().fit_transform(X)
y = y - np.mean(y) # Center y like in test file

# %%
# Model Comparison
# ----------------
# We compare scikit-learn's QuantileRegressor with our SmoothQuantileRegressor
# at the median (tau = 0.5).

tau = 0.5 # median (SmoothQuantileRegressor works much better for non-median quantiles)
alpha = 0.1


def pinball_loss(y_true, y_pred, tau=0.5):
"""Compute Pinball (quantile) loss."""
residuals = y_true - y_pred
return np.mean(np.where(residuals >= 0,
tau * residuals,
(1 - tau) * -residuals))


# scikit-learn's QuantileRegressor
start_time = time.time()
qr = QuantileRegressor(quantile=tau, alpha=alpha, fit_intercept=True,
solver="highs").fit(X, y)
qr_time = time.time() - start_time
y_pred_qr = qr.predict(X)
qr_loss = pinball_loss(y, y_pred_qr, tau=tau)

# SmoothQuantileRegressor
start_time = time.time()
solver = FISTA(max_iter=2000, tol=1e-8)
solver.fit_intercept = True
sqr = SmoothQuantileRegressor(
smoothing_sequence=[1.0, 0.5, 0.2, 0.1, 0.05], # Base sequence, will be extended
quantile=tau, alpha=alpha, verbose=True, # Enable verbose to see stages
smooth_solver=solver
).fit(X, y)
sqr_time = time.time() - start_time
y_pred_sqr = sqr.predict(X)
sqr_loss = pinball_loss(y, y_pred_sqr, tau=tau)

# %%
# Performance Analysis
# --------------------
# Let's analyze the runtime and solution quality of both methods.

speedup = qr_time / sqr_time
rel_gap = (sqr_loss - qr_loss) / qr_loss

print("\nPerformance Summary:")
print("scikit-learn QuantileRegressor:")
print(f" Time: {qr_time:.2f}s")
print(f" Loss: {qr_loss:.6f}")
print("SmoothQuantileRegressor:")
print(f" Time: {sqr_time:.2f}s")
print(f" Loss: {sqr_loss:.6f}")
print(f" Speedup: {speedup:.1f}x")
print(f" Relative gap: {rel_gap:.1%}")

# %%
# Visual Comparison
# -----------------
# We create visualizations to compare the predictions and residuals
# of both methods.

# Sort data for better visualization
sort_idx = np.argsort(y)
y_sorted = y[sort_idx]
qr_pred = y_pred_qr[sort_idx]
sqr_pred = y_pred_sqr[sort_idx]

# Create figure with two subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Plot predictions
ax1.scatter(y_sorted, qr_pred, alpha=0.5, label='scikit-learn', s=10)
ax1.scatter(y_sorted, sqr_pred, alpha=0.5, label='SmoothQuantile', s=10)
ax1.plot([y_sorted.min(), y_sorted.max()],
[y_sorted.min(), y_sorted.max()], 'k--', alpha=0.3)
ax1.set_xlabel('True values')
ax1.set_ylabel('Predicted values')
ax1.set_title(f'Predictions (τ={tau})')
ax1.legend()

# Plot residuals
qr_residuals = y_sorted - qr_pred
sqr_residuals = y_sorted - sqr_pred
ax2.hist(qr_residuals, bins=50, alpha=0.5, label='scikit-learn')
ax2.hist(sqr_residuals, bins=50, alpha=0.5, label='SmoothQuantile')
ax2.axvline(x=0, color='k', linestyle='--', alpha=0.3)
ax2.set_xlabel('Residuals')
ax2.set_ylabel('Count')
ax2.set_title('Residual Distribution')
ax2.legend()

plt.tight_layout()

# %%
# Progressive Smoothing Analysis
# ------------------------------
# Let's examine how the smoothing parameter affects the solution quality.

stages = sqr.stage_results_
deltas = [s['delta'] for s in stages]
errors = [s['quantile_error'] for s in stages]
losses = [s['obj_value'] for s in stages]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Plot quantile error progression
ax1.semilogx(deltas, errors, 'o-')
ax1.set_xlabel('Smoothing parameter (δ)')
ax1.set_ylabel('Quantile error')
ax1.set_title('Quantile Error vs Smoothing')
ax1.grid(True, alpha=0.3)

# Plot objective value progression
ax2.semilogx(deltas, losses, 'o-')
ax2.set_xlabel('Smoothing parameter (δ)')
ax2.set_ylabel('Objective value')
ax2.set_title('Objective Value vs Smoothing')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
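As a follow-up to the contributor's note above that the gains are largest away from the median, here is a minimal sketch of rerunning the same comparison at tau = 0.8. It is not part of the committed example: the names tau_high, qr_high, solver_high and sqr_high are introduced only for illustration, and the snippet reuses the data, alpha, imports and pinball_loss helper defined in the example.

# Hypothetical rerun at a non-median quantile, reusing the objects above.
tau_high = 0.8

qr_high = QuantileRegressor(quantile=tau_high, alpha=alpha,
                            fit_intercept=True, solver="highs").fit(X, y)

solver_high = FISTA(max_iter=2000, tol=1e-8)
solver_high.fit_intercept = True
sqr_high = SmoothQuantileRegressor(
    smoothing_sequence=[1.0, 0.5, 0.2, 0.1, 0.05],
    quantile=tau_high, alpha=alpha,
    smooth_solver=solver_high,
).fit(X, y)

print(f"tau=0.8 QuantileRegressor loss:       "
      f"{pinball_loss(y, qr_high.predict(X), tau=tau_high):.6f}")
print(f"tau=0.8 SmoothQuantileRegressor loss: "
      f"{pinball_loss(y, sqr_high.predict(X), tau=tau_high):.6f}")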
2 changes: 2 additions & 0 deletions skglm/experimental/__init__.py
@@ -2,11 +2,13 @@
from .sqrt_lasso import SqrtLasso, SqrtQuadratic
from .pdcd_ws import PDCD_WS
from .quantile_regression import Pinball
from .quantile_huber import QuantileHuber

__all__ = [
IterativeReweightedL1,
PDCD_WS,
Pinball,
SqrtQuadratic,
SqrtLasso,
QuantileHuber,
]
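With this export in place, the datafit can be imported from skglm.experimental directly. A quick sketch of exercising the new import path on a toy problem; the variable names and array values are illustrative only and rely solely on the methods defined in quantile_huber.py below.

import numpy as np
from skglm.experimental import QuantileHuber

# Evaluate the smoothed pinball datafit on a tiny dense problem.
datafit = QuantileHuber(delta=0.05, quantile=0.8)
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = np.array([0.3, -0.2, 1.1])
Xw = np.zeros(3)                    # predictions for w = 0
print(datafit.value(y, None, Xw))   # mean smoothed pinball loss
print(datafit.gradient(X, y, Xw))   # gradient with respect to w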
223 changes: 223 additions & 0 deletions skglm/experimental/quantile_huber.py
Collaborator
@floriankozikowski to avoid having too many files, this should be merged with smooth_quantile.py, since the two contain closely related code

@@ -0,0 +1,223 @@
import numpy as np
from numpy.linalg import norm
from numba import float64
from skglm.datafits.base import BaseDatafit
from skglm.utils.sparse_ops import spectral_norm


class QuantileHuber(BaseDatafit):
r"""Smoothed approximation of the pinball loss for quantile regression.

This class implements a smoothed version of the pinball loss used in quantile
regression. The original pinball loss is:

.. math::
\rho_\tau(r) = \begin{cases}
\tau r & \text{if } r \geq 0 \\
(\tau - 1) r & \text{if } r < 0
\end{cases}

The smoothed version (Huberized pinball loss) implemented here weights the
quadratic zone by the same factors as the linear tails:

.. math::
    \rho_\tau^\delta(r) = \begin{cases}
        \tau \left(r - \frac{\delta}{2}\right) & \text{if } r \geq \delta \\
        \frac{\tau r^2}{2\delta} & \text{if } 0 \leq r < \delta \\
        \frac{(1 - \tau) r^2}{2\delta} & \text{if } -\delta < r < 0 \\
        (1 - \tau) \left(-r - \frac{\delta}{2}\right) & \text{if } r \leq -\delta
    \end{cases}

where :math:`\delta` is the smoothing parameter. As :math:`\delta \to 0`,
the smoothed loss converges to the original pinball loss.

Parameters
----------
delta : float, default=1.0
Smoothing parameter. Smaller values make the approximation closer to
the original pinball loss but may lead to numerical instability.

quantile : float, default=0.5
Quantile level between 0 and 1. When 0.5, the loss is symmetric
(Huber loss). For other values, the loss is asymmetric.

Attributes
----------
delta : float
Current smoothing parameter.

quantile : float
Current quantile level.

Notes
-----
The smoothed loss is continuously differentiable everywhere, making it
suitable for gradient-based optimization methods. The gradient is:

.. math::
    \nabla \rho_\tau^\delta(r) = \begin{cases}
        \tau & \text{if } r \geq \delta \\
        \frac{\tau r}{\delta} & \text{if } 0 \leq r < \delta \\
        \frac{(1 - \tau) r}{\delta} & \text{if } -\delta < r < 0 \\
        \tau - 1 & \text{if } r \leq -\delta
    \end{cases}

The Hessian is piecewise constant:

.. math::
    \nabla^2 \rho_\tau^\delta(r) = \begin{cases}
        \frac{\tau}{\delta} & \text{if } 0 < r < \delta \\
        \frac{1 - \tau}{\delta} & \text{if } -\delta < r < 0 \\
        0 & \text{if } |r| \geq \delta
    \end{cases}

Examples
--------
>>> from skglm.experimental.quantile_huber import QuantileHuber
>>> import numpy as np
>>> loss = QuantileHuber(delta=0.1, quantile=0.8)
>>> X = np.ones((3, 1))
>>> y = np.array([-1.0, 0.0, 1.0])
>>> Xw = np.zeros(3)  # predictions for w = 0
>>> print(loss.value(y, None, Xw))  # mean smoothed pinball loss
>>> print(loss.gradient(X, y, Xw))  # gradient with respect to w
"""

def __init__(self, delta, quantile):
if not 0 < quantile < 1:
raise ValueError("quantile must be between 0 and 1")
if delta <= 0:
raise ValueError("delta must be positive")
self.delta = float(delta)
self.quantile = float(quantile)

def get_spec(self):
spec = (
('delta', float64),
('quantile', float64),
)
return spec

def params_to_dict(self):
return dict(delta=self.delta, quantile=self.quantile)

def get_lipschitz(self, X, y):
n_samples = len(y)
weight = max(self.quantile, 1 - self.quantile)

lipschitz = weight * (X ** 2).sum(axis=0) / (n_samples * self.delta)
return lipschitz

def get_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
n_samples = len(y)
n_features = len(X_indptr) - 1
weight = max(self.quantile, 1 - self.quantile)

lipschitz = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
nrm2 = 0.0
for idx in range(X_indptr[j], X_indptr[j + 1]):
nrm2 += X_data[idx] ** 2
lipschitz[j] = weight * nrm2 / (n_samples * self.delta)
return lipschitz

def get_global_lipschitz(self, X, y):
n_samples = len(y)
weight = max(self.quantile, 1 - self.quantile)
return weight * norm(X, 2) ** 2 / (n_samples * self.delta)

def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
n_samples = len(y)
weight = max(self.quantile, 1 - self.quantile)
return (
weight
* spectral_norm(X_data, X_indptr, X_indices, n_samples) ** 2
/ (n_samples * self.delta)
)

def _loss_and_grad_scalar(self, residual):
tau, delta = self.quantile, self.delta
abs_r = abs(residual)

if abs_r <= delta:
# Quadratic region
if residual > 0:
return tau * residual**2 / (2 * delta), tau * residual / delta
else:
return ((1 - tau) * residual**2 / (2 * delta), (1 - tau)
* residual / delta
)

# Linear tails
if residual > delta:
return tau * (residual - delta/2), tau
else: # residual < -delta
return (1 - tau) * (-residual - delta/2), -(1 - tau)

def value(self, y, w, Xw):
n_samples = len(y)
res = 0.0
for i in range(n_samples):
loss_i, _ = self._loss_and_grad_scalar(y[i] - Xw[i])
res += loss_i
return res / n_samples

def _dr(self, residual):
"""Compute dl/dr for each residual."""
tau = self.quantile
delt = self.delta

# Pick tau for r >= 0, (1 - tau) for r < 0
scale = np.where(residual >= 0, tau, 1 - tau)

# Inside the quadratic zone: slope = scale * (r / delt)
# Outside: slope is ± scale, same sign as r
dr = np.where(
np.abs(residual) <= delt,
scale * (residual / delt),
np.sign(residual) * scale
)
return dr

def gradient_scalar(self, X, y, w, Xw, j):
r = y - Xw
dr = self._dr(r)
return - X[:, j].dot(dr) / len(y)

def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
r = y - Xw
dr = self._dr(r)
idx_start, idx_end = X_indptr[j], X_indptr[j + 1]
rows = X_indices[idx_start:idx_end]
vals = X_data[idx_start:idx_end]
return - np.dot(vals, dr[rows]) / len(y)

def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw):
n_features = len(X_indptr) - 1
n_samples = len(y)
grad = np.zeros(n_features, dtype=Xw.dtype)
for j in range(n_features):
g = 0.0
for idx in range(X_indptr[j], X_indptr[j + 1]):
i = X_indices[idx]
residual = y[i] - Xw[i]
_, grad_r = self._loss_and_grad_scalar(residual)
g += -X_data[idx] * grad_r
grad[j] = g / n_samples
return grad

def intercept_update_step(self, y, Xw):
n_samples = len(y)
update = 0.0
for i in range(n_samples):
residual = y[i] - Xw[i]
_, grad_r = self._loss_and_grad_scalar(residual)
update += -grad_r
return update / n_samples

def initialize(self, X, y):
pass

def initialize_sparse(self, X_data, X_indptr, X_indices, y):
pass

def gradient(self, X, y, Xw):
n_samples, n_features = X.shape
grad = np.zeros(n_features)
for j in range(n_features):
grad[j] = self.gradient_scalar(X, y, None, Xw, j)
return grad
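A quick numerical sanity check one might run against this datafit (not part of the PR; _finite_diff_check is a hypothetical helper): the vectorized slope from _dr should agree with a central finite difference of the scalar loss returned by _loss_and_grad_scalar, up to the usual O(eps) error near the kinks at r = 0 and r = ±delta.

import numpy as np
from skglm.experimental.quantile_huber import QuantileHuber


def _finite_diff_check(delta=0.1, quantile=0.8, eps=1e-6):
    """Return the max gap between analytic and finite-difference slopes."""
    qh = QuantileHuber(delta=delta, quantile=quantile)
    r = np.linspace(-0.5, 0.5, 21)
    analytic = qh._dr(r)
    numeric = np.array([
        (qh._loss_and_grad_scalar(ri + eps)[0]
         - qh._loss_and_grad_scalar(ri - eps)[0]) / (2 * eps)
        for ri in r
    ])
    return np.max(np.abs(analytic - numeric))


print(_finite_diff_check())  # expected to be close to zero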