|
4 | 4 | import numpy as np
|
5 | 5 | from scipy.sparse import issparse
|
6 | 6 | from scipy.special import expit
|
7 |
| -from skglm.solvers.prox_newton import ProxNewton |
| 7 | +from numbers import Integral, Real |
| 8 | +from skglm.solvers import ProxNewton, LBFGS |
8 | 9 |
|
9 |
| -from sklearn.utils.validation import check_is_fitted |
10 |
| -from sklearn.utils import check_array, check_consistent_length |
| 10 | +from sklearn.utils.validation import (check_is_fitted, check_array, |
| 11 | + check_consistent_length) |
11 | 12 | from sklearn.linear_model._base import (
|
12 | 13 | LinearModel, RegressorMixin,
|
13 | 14 | LinearClassifierMixin, SparseCoefMixin, BaseEstimator
|
14 | 15 | )
|
15 | 16 | from sklearn.utils.extmath import softmax
|
16 | 17 | from sklearn.preprocessing import LabelEncoder
|
| 18 | +from sklearn.utils._param_validation import Interval, StrOptions |
17 | 19 | from sklearn.multiclass import OneVsRestClassifier, check_classification_targets
|
18 | 20 |
|
19 | 21 | from skglm.utils.jit_compilation import compiled_clone
|
20 | 22 | from skglm.solvers import AndersonCD, MultiTaskBCD
|
21 |
| -from skglm.datafits import Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask |
22 |
| -from skglm.penalties import L1, WeightedL1, L1_plus_L2, MCPenalty, IndicatorBox, L2_1 |
| 23 | +from skglm.datafits import Cox, Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask |
| 24 | +from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, |
| 25 | + MCPenalty, IndicatorBox, L2_1) |
23 | 26 |
|
24 | 27 |
|
25 | 28 | def _glm_fit(X, y, model, datafit, penalty, solver):
|
@@ -1159,6 +1162,169 @@ def fit(self, X, y):
|
1159 | 1162 | # TODO add predict_proba for LinearSVC
|
1160 | 1163 |
|
1161 | 1164 |
|
class CoxEstimator(LinearModel):
    r"""Elastic Cox estimator with Efron and Breslow estimate.

    Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
    for details about the datafit expression. The data convention for the
    estimator is

    - ``X`` the design matrix with ``n_features`` predictors
    - ``y`` a two-column array where the first ``tm`` is of event time
      occurrences and the second ``s`` is of censoring.

    For L2-regularized Cox (``l1_ratio=0.``) :ref:`LBFGS <skglm.solvers.LBFGS>`
    is the used solver, otherwise it is
    :ref:`ProxNewton <skglm.solvers.ProxNewton>`.

    Parameters
    ----------
    alpha : float, default=1.0
        Penalty strength. It must be strictly positive.

    l1_ratio : float, default=0.7
        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
        ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
        combination of L1 and L2.

    method : {'efron', 'breslow'}, default='efron'
        The estimate used for the Cox datafit. Use ``efron`` to
        handle tied observations.

    tol : float, default=1e-4
        Stopping criterion for the optimization. It must be non-negative.

    max_iter : int, default=50
        The maximum number of iterations to solve the problem.
        It must be at least 1.

    verbose : bool or int, default=False
        Amount of verbosity. Integer values must lie in ``[0, 2]``.

    Attributes
    ----------
    coef_ : array, shape (n_features,)
        Parameter vector of Cox regression.

    stop_crit_ : float
        The value of the stopping criterion at convergence.

    intercept_ : float
        Always ``0.``; the model is fitted without intercept.

    n_features_in_ : int
        Number of columns of ``X`` seen during ``fit``.

    feature_names_in_ : array, shape (n_features,)
        Integer positions ``0, ..., n_features - 1`` of the columns of ``X``.
    """

    _parameter_constraints: dict = {
        "alpha": [Interval(Real, 0, None, closed="neither")],
        "l1_ratio": [Interval(Real, 0, 1, closed="both")],
        "method": [StrOptions({"efron", "breslow"})],
        "tol": [Interval(Real, 0, None, closed="left")],
        "max_iter": [Interval(Integral, 1, None, closed="left")],
        "verbose": ["boolean", Interval(Integral, 0, 2, closed="both")],
    }

    def __init__(self, alpha=1., l1_ratio=0.7, method="efron", tol=1e-4,
                 max_iter=50, verbose=False):
        self.alpha = alpha
        self.l1_ratio = l1_ratio
        self.method = method
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose

    def fit(self, X, y):
        """Fit Cox estimator.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Design matrix.

        y : array-like, shape (n_samples, 2)
            Two-column array where the first is of event time occurrences
            and the second is of censoring. If it is of dimension 1, it is
            assumed to be the times vector and there is no censoring.

        Returns
        -------
        self :
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``y`` is ``None`` or if ``y`` is two-dimensional with a number
            of columns different from two.
        """
        self._validate_params()

        # validate input data
        X = check_array(
            X,
            accept_sparse="csc",
            order="F",
            dtype=[np.float64, np.float32],
            input_name="X",
        )
        if y is None:
            # Needed to pass check estimator. Message error is
            # copy/paste from https://github.com/scikit-learn/scikit-learn/blob/ \
            # 23ff51c07ebc03c866984e93c921a8993e96d1f9/sklearn/utils/ \
            # estimator_checks.py#L3886
            raise ValueError("requires y to be passed, but the target y is None")
        y = check_array(
            y,
            accept_sparse=False,
            order="F",
            dtype=X.dtype,
            ensure_2d=False,
            input_name="y",
        )
        if y.ndim == 1:
            warnings.warn(
                f"{repr(self)} requires the vector of response `y` to have "
                f"two columns. Got one column.\nAssuming that `y` "
                "is the vector of times and there is no censoring."
            )
            # second column of ones: every sample is treated as uncensored
            y = np.column_stack((y, np.ones_like(y))).astype(X.dtype, order="F")
        elif y.shape[1] != 2:
            # Reject any 2-D `y` that does not have exactly two columns,
            # including a single-column (n_samples, 1) array, instead of
            # handing a malformed response to the datafit.
            raise ValueError(
                f"{repr(self)} requires the vector of response `y` to have "
                f"two columns. Got {y.shape[1]} columns."
            )

        check_consistent_length(X, y)

        # init datafit
        datafit = Cox(self.method)

        # init penalty and its matching solver in a single place:
        # the smooth L2 case goes to LBFGS, anything with an L1 part
        # goes to ProxNewton
        if self.l1_ratio == 0.:
            penalty = L2(self.alpha)
            solver = LBFGS(
                max_iter=self.max_iter, tol=self.tol, verbose=self.verbose
            )
        else:
            if self.l1_ratio == 1.:
                penalty = L1(self.alpha)
            else:
                penalty = L1_plus_L2(self.alpha, self.l1_ratio)
            solver = ProxNewton(
                max_iter=self.max_iter, tol=self.tol, verbose=self.verbose,
                fit_intercept=False,
            )

        # skglm internal: JIT compile classes
        datafit = compiled_clone(datafit)
        penalty = compiled_clone(penalty)

        # solve problem
        if not issparse(X):
            datafit.initialize(X, y)
        else:
            datafit.initialize_sparse(X.data, X.indptr, X.indices, y)

        w, _, stop_crit = solver.solve(X, y, datafit, penalty)

        # save to attribute
        self.coef_ = w
        self.stop_crit_ = stop_crit

        self.intercept_ = 0.
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = np.arange(X.shape[1])

        return self
| 1326 | + |
| 1327 | + |
1162 | 1328 | class MultiTaskLasso(LinearModel, RegressorMixin):
|
1163 | 1329 | r"""MultiTaskLasso estimator.
|
1164 | 1330 |
|
|
0 commit comments