@@ -115,12 +115,21 @@ def get_global_lipschitz(self, X, y):
115115 c = max (self .quantile , 1 - self .quantile ) / self .delta
116116 return c * norm (X , ord = 2 ) ** 2 / len (y )
117117
118+ def intercept_update_step (self , y , Xw ):
119+ n_samples = len (y )
120+ update = 0.0
121+ for i in range (n_samples ):
122+ residual = y [i ] - Xw [i ]
123+ update -= self ._grad_per_sample (residual )
124+ return update / n_samples
125+
118126
119127class SmoothQuantileRegressor (BaseEstimator , RegressorMixin ):
120128 """Quantile regression with progressive smoothing."""
121129
122130 def __init__ (self , quantile = 0.5 , alpha = 0.1 , delta_init = 1.0 , delta_final = 1e-3 ,
123- n_deltas = 10 , max_iter = 1000 , tol = 1e-4 , verbose = False , solver = "FISTA" ):
131+ n_deltas = 10 , max_iter = 1000 , tol = 1e-4 , verbose = False ,
132+ solver = "AndersonCD" , fit_intercept = True ):
124133 self .quantile = quantile
125134 self .alpha = alpha
126135 self .delta_init = delta_init
@@ -130,6 +139,7 @@ def __init__(self, quantile=0.5, alpha=0.1, delta_init=1.0, delta_final=1e-3,
130139 self .tol = tol
131140 self .verbose = verbose
132141 self .solver = solver
142+ self .fit_intercept = fit_intercept
133143
134144 def fit (self , X , y ):
135145 """Fit using progressive smoothing: delta_init --> delta_final."""
@@ -146,11 +156,18 @@ def fit(self, X, y):
146156 # Solver selection
147157 if isinstance (self .solver , str ):
148158 if self .solver == "FISTA" :
159+ if self .fit_intercept :
160+ import warnings
161+ warnings .warn (
162+ "FISTA solver does not support intercept. "
163+ "Falling back to fit_intercept=False."
164+ )
165+ self .fit_intercept = False
149166 solver = FISTA (max_iter = self .max_iter , tol = self .tol )
150167 solver .warm_start = True
151168 elif self .solver == "AndersonCD" :
152169 solver = AndersonCD (max_iter = self .max_iter , tol = self .tol ,
153- warm_start = True , fit_intercept = False )
170+ warm_start = True , fit_intercept = self . fit_intercept )
154171 else :
155172 raise ValueError (f"Unknown solver: { self .solver } " )
156173 else :
@@ -167,6 +184,8 @@ def fit(self, X, y):
167184
168185 if self .verbose :
169186 residuals = y - X @ w
187+ if self .fit_intercept :
188+ residuals -= est .intercept_
170189 coverage = np .mean (residuals <= 0 )
171190 pinball_loss = np .mean (residuals * (self .quantile - (residuals < 0 )))
172191
0 commit comments