@@ -59,10 +59,12 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+        dtype = X.dtype
        n_samples, n_features = X.shape
        fit_intercept = self.fit_intercept
-        w = np.zeros(n_features + fit_intercept) if w_init is None else w_init
-        Xw = np.zeros(n_samples) if Xw_init is None else Xw_init
+
+        w = np.zeros(n_features + fit_intercept, dtype) if w_init is None else w_init
+        Xw = np.zeros(n_samples, dtype) if Xw_init is None else Xw_init
        all_features = np.arange(n_features)
        stop_crit = 0.
        p_objs_out = []
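
The pattern in this first hunk, reading `dtype = X.dtype` once and passing it to every `np.zeros` call, is what keeps the solver in single precision when the design matrix is `float32`: `np.zeros` allocates `float64` by default, and the first mixed operation would silently promote everything back to double precision. A minimal NumPy sketch of the promotion being avoided (illustrative only, not part of the patch):

```python
import numpy as np

X = np.random.default_rng(0).standard_normal((5, 3)).astype(np.float32)

w_default = np.zeros(X.shape[1])          # float64 by default
w_typed = np.zeros(X.shape[1], X.dtype)   # stays float32

print((X @ w_default).dtype)  # float64: the product is promoted
print((X @ w_typed).dtype)    # float32: dtype is preserved end to end
```
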
@@ -181,16 +183,17 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
    # Minimize quadratic approximation for delta_w = w - w_epoch:
    # b.T @ X @ delta_w + \
    # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w)
+    dtype = X.dtype
    raw_hess = datafit.raw_hessian(y, Xw_epoch)

-    lipschitz = np.zeros(len(ws))
+    lipschitz = np.zeros(len(ws), dtype)
    for idx, j in enumerate(ws):
        lipschitz[idx] = raw_hess @ X[:, j] ** 2

    # for a less costly stopping criterion, we do not compute the exact gradient,
    # but store each coordinate-wise gradient every time we update one coordinate
-    past_grads = np.zeros(len(ws))
-    X_delta_w_ws = np.zeros(X.shape[0])
+    past_grads = np.zeros(len(ws), dtype)
+    X_delta_w_ws = np.zeros(X.shape[0], dtype)
    ws_intercept = np.append(ws, -1) if fit_intercept else ws
    w_ws = w_epoch[ws_intercept]
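
As context for the `lipschitz` loop above: each entry is the curvature of the quadratic model along one coordinate, i.e. the `j`-th diagonal entry of `X.T @ D @ X` with `D = diag(raw_hess)`, which is exactly `raw_hess @ X[:, j] ** 2`. A small stand-alone check of that identity (plain NumPy, outside numba):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4)).astype(np.float32)
raw_hess = rng.random(20).astype(np.float32)   # stand-in for datafit.raw_hessian

# per-coordinate curvature, as computed in the solver loop
lipschitz = np.array([raw_hess @ X[:, j] ** 2 for j in range(X.shape[1])])

# same quantity as the diagonal of X.T @ diag(raw_hess) @ X
full_hessian_diag = np.diag(X.T @ np.diag(raw_hess) @ X)

print(np.allclose(lipschitz, full_hessian_diag))  # True
```
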
@@ -243,17 +246,18 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
@njit
def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                         Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol):
+    dtype = X_data.dtype
    raw_hess = datafit.raw_hessian(y, Xw_epoch)

-    lipschitz = np.zeros(len(ws))
+    lipschitz = np.zeros(len(ws), dtype)
    for idx, j in enumerate(ws):
        # equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2
        lipschitz[idx] = _sparse_squared_weighted_norm(
            X_data, X_indptr, X_indices, j, raw_hess)

    # see _descent_direction() comment
-    past_grads = np.zeros(len(ws))
-    X_delta_w_ws = np.zeros(Xw_epoch.shape[0])
+    past_grads = np.zeros(len(ws), dtype)
+    X_delta_w_ws = np.zeros(Xw_epoch.shape[0], dtype)
    ws_intercept = np.append(ws, -1) if fit_intercept else ws
    w_ws = w_epoch[ws_intercept]
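
The sparse variant gets the same per-coordinate curvature through `_sparse_squared_weighted_norm`, whose body is not part of this hunk. To make the inline comment concrete, a hypothetical re-implementation over CSC buffers could look like the following (an assumption about the helper's behavior, not the code actually defined in skglm):

```python
def sparse_squared_weighted_norm(X_data, X_indptr, X_indices, j, weights):
    # sum_i weights[i] * X[i, j] ** 2, iterating only over the nonzeros
    # of column j of a CSC matrix (hypothetical sketch of the helper)
    nrm = 0.0
    for k in range(X_indptr[j], X_indptr[j + 1]):
        nrm += weights[X_indices[k]] * X_data[k] ** 2
    return nrm
```

Called on the `data`, `indptr` and `indices` arrays of a CSC matrix, such a helper matches the dense expression `raw_hess @ X[:, j] ** 2` while touching only the nonzeros of column `j`.
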
@@ -329,7 +333,11 @@ def _backtrack_line_search(X, y, w, Xw, fit_intercept, datafit, penalty, delta_w
        grad_ws = _construct_grad(X, y, w[:n_features], Xw, datafit, ws)
        # TODO: could be improved by passing in w[ws]
        stop_crit = penalty.value(w[:n_features]) - old_penalty_val
-        stop_crit += step * grad_ws @ delta_w_ws[:len(ws)]
+
+        # it is mandatory to split the two operations, otherwise numba raises an error
+        # cf. https://github.com/numba/numba/issues/9025
+        dot = grad_ws @ delta_w_ws[:len(ws)]
+        stop_crit += step * dot

        if fit_intercept:
            stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
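
The workaround itself is mechanical, and the same rewrite is applied to the sparse line search in the next hunk: evaluate the dot product into a temporary, then scale it, rather than writing `step * grad_ws @ delta_w_ws[:len(ws)]` in a single expression, which the linked numba issue reports as failing inside jitted code. A plain-Python sketch of the two forms (this does not reproduce the numba error itself):

```python
import numpy as np

step = 0.5
ws = np.arange(3)
grad_ws = np.ones(3, dtype=np.float32)
delta_w_ws = np.full(4, 2.0, dtype=np.float32)  # last entry plays the intercept role

stop_crit = 0.0

# single-expression form removed by the patch (fine in plain NumPy,
# reported to fail under numba):
# stop_crit += step * grad_ws @ delta_w_ws[:len(ws)]

# two-step form kept in the patch
dot = grad_ws @ delta_w_ws[:len(ws)]
stop_crit += step * dot

print(stop_crit)  # 3.0
```
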
@@ -364,7 +372,11 @@ def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, fit_intercep
                                         y, w[:n_features], Xw, datafit, ws)
        # TODO: could be improved by passing in w[ws]
        stop_crit = penalty.value(w[:n_features]) - old_penalty_val
-        stop_crit += step * grad_ws.T @ delta_w_ws[:len(ws)]
+
+        # it is mandatory to split the two operations, otherwise numba raises an error
+        # cf. https://github.com/numba/numba/issues/9025
+        dot = grad_ws.T @ delta_w_ws[:len(ws)]
+        stop_crit += step * dot

        if fit_intercept:
            stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
@@ -385,7 +397,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws):
    # Compute grad of datafit restricted to ws. This function avoids
    # recomputing raw_grad for every j, which is costly for logreg
    raw_grad = datafit.raw_grad(y, Xw)
-    grad = np.zeros(len(ws))
+    grad = np.zeros(len(ws), dtype=X.dtype)
    for idx, j in enumerate(ws):
        grad[idx] = X[:, j] @ raw_grad
    return grad
@@ -395,7 +407,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws):
def _construct_grad_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, ws):
    # Compute grad of datafit restricted to ws in case X sparse
    raw_grad = datafit.raw_grad(y, Xw)
-    grad = np.zeros(len(ws))
+    grad = np.zeros(len(ws), dtype=X_data.dtype)
    for idx, j in enumerate(ws):
        grad[idx] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad)
    return grad