
Commit f6f0875: FIX support of float32 in ProxNewton solver (#170)

Parent: 189d21e

File tree

3 files changed: +47, -26 lines

skglm/solvers/prox_newton.py

Lines changed: 24 additions & 12 deletions
@@ -59,10 +59,12 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
         self.verbose = verbose
 
     def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+        dtype = X.dtype
         n_samples, n_features = X.shape
         fit_intercept = self.fit_intercept
-        w = np.zeros(n_features + fit_intercept) if w_init is None else w_init
-        Xw = np.zeros(n_samples) if Xw_init is None else Xw_init
+
+        w = np.zeros(n_features + fit_intercept, dtype) if w_init is None else w_init
+        Xw = np.zeros(n_samples, dtype) if Xw_init is None else Xw_init
         all_features = np.arange(n_features)
         stop_crit = 0.
         p_objs_out = []
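For context (not part of the diff): np.zeros allocates float64 by default, so any work buffer created without an explicit dtype silently promotes float32 inputs back to float64 and can also trip dtype-strict numba kernels. A minimal sketch of the behavior the fix targets:

```python
import numpy as np

X = np.ones((5, 3), dtype=np.float32)

w_default = np.zeros(X.shape[1])               # float64, NumPy's default
w_typed = np.zeros(X.shape[1], dtype=X.dtype)  # float32, matches X

print((X @ w_default).dtype)  # float64: the float32 data is upcast
print((X @ w_typed).dtype)    # float32: dtype preserved end to end
```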
@@ -181,16 +183,17 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
     # Minimize quadratic approximation for delta_w = w - w_epoch:
     # b.T @ X @ delta_w + \
     # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w)
+    dtype = X.dtype
     raw_hess = datafit.raw_hessian(y, Xw_epoch)
 
-    lipschitz = np.zeros(len(ws))
+    lipschitz = np.zeros(len(ws), dtype)
     for idx, j in enumerate(ws):
         lipschitz[idx] = raw_hess @ X[:, j] ** 2
 
     # for a less costly stopping criterion, we do not compute the exact gradient,
     # but store each coordinate-wise gradient every time we update one coordinate
-    past_grads = np.zeros(len(ws))
-    X_delta_w_ws = np.zeros(X.shape[0])
+    past_grads = np.zeros(len(ws), dtype)
+    X_delta_w_ws = np.zeros(X.shape[0], dtype)
     ws_intercept = np.append(ws, -1) if fit_intercept else ws
     w_ws = w_epoch[ws_intercept]
 
@@ -243,17 +246,18 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
 @njit
 def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                          Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol):
+    dtype = X_data.dtype
     raw_hess = datafit.raw_hessian(y, Xw_epoch)
 
-    lipschitz = np.zeros(len(ws))
+    lipschitz = np.zeros(len(ws), dtype)
     for idx, j in enumerate(ws):
         # equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2
         lipschitz[idx] = _sparse_squared_weighted_norm(
             X_data, X_indptr, X_indices, j, raw_hess)
 
     # see _descent_direction() comment
-    past_grads = np.zeros(len(ws))
-    X_delta_w_ws = np.zeros(Xw_epoch.shape[0])
+    past_grads = np.zeros(len(ws), dtype)
+    X_delta_w_ws = np.zeros(Xw_epoch.shape[0], dtype)
     ws_intercept = np.append(ws, -1) if fit_intercept else ws
     w_ws = w_epoch[ws_intercept]
 
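Side note on the Lipschitz terms: raw_hess @ X[:, j] ** 2 is the raw-Hessian-weighted squared norm of column j, which _sparse_squared_weighted_norm reproduces over the CSC arrays. A toy sanity check of the dense identity (illustrative data, not from the repo):

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(4, 2).astype(np.float32)
raw_hess = rng.rand(4).astype(np.float32)  # per-sample second derivatives

j = 1
# ** binds tighter than @, so this reads raw_hess @ (X[:, j] ** 2)
lipschitz_j = raw_hess @ X[:, j] ** 2
assert np.isclose(lipschitz_j, np.sum(raw_hess * X[:, j] ** 2))
```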
@@ -329,7 +333,11 @@ def _backtrack_line_search(X, y, w, Xw, fit_intercept, datafit, penalty, delta_w
         grad_ws = _construct_grad(X, y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
         stop_crit = penalty.value(w[:n_features]) - old_penalty_val
-        stop_crit += step * grad_ws @ delta_w_ws[:len(ws)]
+
+        # it is mandatory to split the two operations, otherwise numba raises an error
+        # cf. https://github.com/numba/numba/issues/9025
+        dot = grad_ws @ delta_w_ws[:len(ws)]
+        stop_crit += step * dot
 
         if fit_intercept:
             stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
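The split introduced above works around a numba typing failure with the fused expression. A sketch of the two forms (whether the fused one actually compiles depends on the numba version; see the issue linked in the comment):

```python
import numpy as np
from numba import njit


@njit
def fused(step, grad_ws, delta_w_ws):
    # single fused expression; reported to raise a typing error in some
    # numba versions (https://github.com/numba/numba/issues/9025)
    return step * grad_ws @ delta_w_ws


@njit
def split(step, grad_ws, delta_w_ws):
    # workaround adopted in this commit: materialize the dot product first
    dot = grad_ws @ delta_w_ws
    return step * dot
```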
@@ -364,7 +372,11 @@ def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, fit_intercep
             y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
         stop_crit = penalty.value(w[:n_features]) - old_penalty_val
-        stop_crit += step * grad_ws.T @ delta_w_ws[:len(ws)]
+
+        # it is mandatory to split the two operations, otherwise numba raises an error
+        # cf. https://github.com/numba/numba/issues/9025
+        dot = grad_ws.T @ delta_w_ws[:len(ws)]
+        stop_crit += step * dot
 
         if fit_intercept:
             stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
@@ -385,7 +397,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws):
     # Compute grad of datafit restricted to ws. This function avoids
     # recomputing raw_grad for every j, which is costly for logreg
     raw_grad = datafit.raw_grad(y, Xw)
-    grad = np.zeros(len(ws))
+    grad = np.zeros(len(ws), dtype=X.dtype)
     for idx, j in enumerate(ws):
         grad[idx] = X[:, j] @ raw_grad
     return grad
@@ -395,7 +407,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws):
 def _construct_grad_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, ws):
     # Compute grad of datafit restricted to ws in case X sparse
     raw_grad = datafit.raw_grad(y, Xw)
-    grad = np.zeros(len(ws))
+    grad = np.zeros(len(ws), dtype=X_data.dtype)
     for idx, j in enumerate(ws):
         grad[idx] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad)
     return grad
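Taken together, these changes let the solver carry float32 inputs through without upcasting. A hedged end-to-end sketch (the classes are skglm's public API, but constructor arguments may differ slightly across versions):

```python
import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import Logistic
from skglm.penalties import L1
from skglm.solvers import ProxNewton

rng = np.random.RandomState(0)
X = rng.randn(100, 20).astype(np.float32)
y = np.where(rng.randn(100) > 0, 1.0, -1.0).astype(np.float32)  # labels in {-1, 1}

model = GeneralizedLinearEstimator(
    datafit=Logistic(), penalty=L1(alpha=0.01), solver=ProxNewton())
model.fit(X, y)

print(model.coef_.dtype)  # float32 with this fix, float64 before it
```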

skglm/tests/test_estimators.py

Lines changed: 12 additions & 7 deletions
@@ -168,8 +168,9 @@ def test_mtl_path():
     np.testing.assert_allclose(coef_ours, coef_sk, rtol=1e-5)
 
 
-@pytest.mark.parametrize("use_efron", [True, False])
-def test_CoxEstimator(use_efron):
+@pytest.mark.parametrize("use_efron, use_float_32",
+                         product([True, False], [True, False]))
+def test_CoxEstimator(use_efron, use_float_32):
     try:
         from lifelines import CoxPHFitter
     except ModuleNotFoundError:
@@ -184,7 +185,8 @@ def test_CoxEstimator(use_efron):
     random_state = 1265
 
     tm, s, X = make_dummy_survival_data(n_samples, n_features, normalize=True,
-                                        with_ties=use_efron, random_state=random_state)
+                                        with_ties=use_efron, use_float_32=use_float_32,
+                                        random_state=random_state)
 
     # compute alpha_max
     B = (tm >= tm[:, None]).astype(X.dtype)
@@ -214,7 +216,7 @@
         df, duration_col=0, event_col=1,
         fit_options={"max_steps": 10_000, "precision": 1e-12}
     )
-    w_ll = estimator.params_.values
+    w_ll = estimator.params_.values.astype(X.dtype)
 
     p_obj_skglm = datafit.value((tm, s), w, X @ w) + penalty.value(w)
     p_obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
@@ -223,14 +225,16 @@ def test_CoxEstimator(use_efron):
     np.testing.assert_allclose(p_obj_skglm, p_obj_ll, atol=1e-6)
 
 
-@pytest.mark.parametrize("use_efron", [True, False])
-def test_CoxEstimator_sparse(use_efron):
+@pytest.mark.parametrize("use_efron, use_float_32",
+                         product([True, False], [True, False]))
+def test_CoxEstimator_sparse(use_efron, use_float_32):
     reg = 1e-2
     n_samples, n_features = 100, 30
     X_density, random_state = 0.5, 1265
 
     tm, s, X = make_dummy_survival_data(n_samples, n_features, X_density=X_density,
-                                        with_ties=use_efron, random_state=random_state)
+                                        use_float_32=use_float_32, with_ties=use_efron,
+                                        random_state=random_state)
 
     # compute alpha_max
     B = (tm >= tm[:, None]).astype(X.dtype)
@@ -373,4 +377,5 @@ def test_warm_start(estimator_name):
 
 
 if __name__ == "__main__":
+    test_CoxEstimator(True, True)
     pass
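For reference, product in the parametrizations above is itertools.product (assumed already imported at the top of the test module); it expands the two flags into all four combinations:

```python
from itertools import product

print(list(product([True, False], [True, False])))
# [(True, True), (True, False), (False, True), (False, False)]
```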

skglm/utils/data.py

Lines changed: 11 additions & 7 deletions
@@ -124,8 +124,8 @@ def make_correlated_data(
     return X, Y, w_true
 
 
-def make_dummy_survival_data(n_samples, n_features, normalize=False,
-                             X_density=1., with_ties=False, random_state=None):
+def make_dummy_survival_data(n_samples, n_features, normalize=False, X_density=1.,
+                             with_ties=False, use_float_32=False, random_state=None):
     """Generate a random dataset for survival analysis.
 
     The design matrix ``X`` is generated according to standard normal, the vector of
@@ -152,6 +152,9 @@ def make_dummy_survival_data(n_samples, n_features, normalize=False,
         Determine if the data contains tied observations: observations with the same
         occurrences times ``tm``.
 
+    use_float_32 : bool, default=False
+        If ``True``, returns data with type ``float32``; otherwise ``float64``.
+
     random_state : int, default=None
         Determines random number generation for data generation.
 
@@ -167,20 +170,21 @@ def make_dummy_survival_data(n_samples, n_features, normalize=False,
         The matrix of predictors. If ``density < 1``, a CSC sparse matrix is returned.
     """
     rng = np.random.RandomState(random_state)
+    dtype = np.float64 if use_float_32 is False else np.float32
 
     if X_density == 1.:
-        X = rng.randn(n_samples, n_features).astype(float, order='F')
+        X = rng.randn(n_samples, n_features).astype(dtype, order='F')
     else:
         X = scipy.sparse.rand(
-            n_samples, n_features, density=X_density, format="csc", dtype=float)
+            n_samples, n_features, density=X_density, format="csc", dtype=dtype)
 
     if not with_ties:
-        tm = rng.weibull(a=1, size=n_samples)
+        tm = rng.weibull(a=1, size=n_samples).astype(dtype)
     else:
-        unique_tm = rng.weibull(a=1, size=n_samples // 10 + 1)
+        unique_tm = rng.weibull(a=1, size=n_samples // 10 + 1).astype(dtype)
         tm = rng.choice(unique_tm, size=n_samples)
 
-    s = rng.choice(2, size=n_samples).astype(float)
+    s = rng.choice(2, size=n_samples).astype(dtype)
 
     if normalize and X_density == 1.:
         X = StandardScaler().fit_transform(X)
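A minimal usage sketch of the new flag (signature and return order taken from the diff and the tests above):

```python
from skglm.utils.data import make_dummy_survival_data

tm, s, X = make_dummy_survival_data(100, 30, use_float_32=True, random_state=0)
print(tm.dtype, s.dtype, X.dtype)  # float32 float32 float32
```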
