Skip to content

Commit 189d21e

Browse files
authored
ENH More efficient B.dot and B.T.dot in Cox datafit (#168)
1 parent e7048b6 commit 189d21e

File tree

1 file changed

+79
-41
lines changed

1 file changed

+79
-41
lines changed

skglm/datafits/single_task.py

Lines changed: 79 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -561,19 +561,24 @@ class Cox(BaseDatafit):
561561
562562
Attributes
563563
----------
564-
B : array-like, shape (n_samples, n_samples)
565-
Matrix where every ``(i, j)`` entry (row, column) equals ``1``
566-
if ``tm[j] >= tm[i]`` and ``0`` otherwise. This matrix is initialized
567-
using the ``.initialize`` method.
564+
T_indices : array-like, shape (n_samples,)
565+
Indices of observations with the same occurrence times stacked horizontally as
566+
``[group_1, group_2, ...]`` in ascending order. It is initialized
567+
with the ``.initialize`` method (or ``initialize_sparse`` for sparse ``X``).
568568
569-
H_indices : array-like, shape (n_samples,)
570-
Indices of observations with the same occurrence times stacked horizontally
571-
as ``[group_1, group_2, ...]``. This array is initialized
572-
when calling ``.initialize`` method when ``use_efron=True``.
569+
T_indptr : array-like, shape (np.unique(tm) + 1,)
570+
Array where two consecutive elements delimit a group of
571+
observations having the same occurrence times.
573572
574-
H_indptr : array-like, (np.unique(tm) + 1,)
575-
Array where two consecutive elements delimits a group of observations
576-
having the same occurrence times.
573+
H_indices : array-like, shape (n_samples,)
574+
Indices of uncensored observations with the same occurrence times stacked
575+
horizontally as ``[group_1, group_2, ...]`` in ascending order.
576+
It is initialized when calling the ``.initialize`` method
577+
(or ``initialize_sparse`` for sparse ``X``) when ``use_efron=True``.
578+
579+
H_indptr : array-like, shape (np.unique(tm[s != 0]) + 1,)
580+
Array where two consecutive elements delimit a group of uncensored
581+
observations having the same occurrence time.
577582
"""
578583

579584
def __init__(self, use_efron=False):
@@ -582,9 +587,8 @@ def __init__(self, use_efron=False):
582587
def get_spec(self):
    """Return the numba specification of the class attributes.

    Each entry is a ``(attribute_name, numba_dtype)`` pair used to
    compile the datafit with numba's jitclass machinery.
    """
    spec = (
        ('use_efron', bool_),
        ('T_indptr', int64[:]),
        ('T_indices', int64[:]),
        ('H_indptr', int64[:]),
        ('H_indices', int64[:]),
    )
    return spec
589593

590594
def params_to_dict(self):
@@ -597,7 +601,7 @@ def value(self, y, w, Xw):
597601

598602
# compute inside log term
599603
exp_Xw = np.exp(Xw)
600-
B_exp_Xw = self.B @ exp_Xw
604+
B_exp_Xw = self._B_dot_vec(exp_Xw)
601605
if self.use_efron:
602606
B_exp_Xw -= self._A_dot_vec(exp_Xw)
603607

@@ -614,12 +618,12 @@ def raw_grad(self, y, Xw):
614618
n_samples = Xw.shape[0]
615619

616620
exp_Xw = np.exp(Xw)
617-
B_exp_Xw = self.B @ exp_Xw
621+
B_exp_Xw = self._B_dot_vec(exp_Xw)
618622
if self.use_efron:
619623
B_exp_Xw -= self._A_dot_vec(exp_Xw)
620624

621625
s_over_B_exp_Xw = s / B_exp_Xw
622-
out = -s + exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
626+
out = -s + exp_Xw * self._B_T_dot_vec(s_over_B_exp_Xw)
623627
if self.use_efron:
624628
out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)
625629

@@ -635,12 +639,12 @@ def raw_hessian(self, y, Xw):
635639
n_samples = Xw.shape[0]
636640

637641
exp_Xw = np.exp(Xw)
638-
B_exp_Xw = self.B @ exp_Xw
642+
B_exp_Xw = self._B_dot_vec(exp_Xw)
639643
if self.use_efron:
640644
B_exp_Xw -= self._A_dot_vec(exp_Xw)
641645

642646
s_over_B_exp_Xw = s / B_exp_Xw
643-
out = exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
647+
out = exp_Xw * self._B_T_dot_vec(s_over_B_exp_Xw)
644648
if self.use_efron:
645649
out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)
646650

def initialize(self, X, y):
    """Initialize the datafit attributes."""
    tm, s = y

    # group observations by occurrence time: `order` sorts times
    # ascending and `T_indptr` delimits runs of equal times
    order = np.argsort(tm)
    self.T_indices = order
    self.T_indptr = self._get_indptr(tm, order)

    if self.use_efron:
        # filter out censored data, preserving the ascending time order
        uncensored = order[s[order] != 0]
        self.H_indices = uncensored
        self.H_indptr = self._get_indptr(tm, uncensored)
681668

682669
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
    """Initialize the datafit attributes in sparse dataset case."""
    # `initialize` and `initialize_sparse` share one implementation;
    # only the dtype of the design matrix is needed, so forwarding
    # ``X_data`` alone is sufficient (small hack to avoid duplication)
    self.initialize(X_data, y)
687674

675+
def _B_dot_vec(self, vec):
676+
# compute `B @ vec` in O(n) instead of O(n^2)
677+
out = np.zeros_like(vec)
678+
n_T = self.T_indptr.shape[0] - 1
679+
cum_sum = 0.
680+
681+
# reverse loop to avoid starting from cum_sum and subtracting vec coordinates
682+
# subtracting big numbers results in 'cancellation errors' and hence erroneous
683+
# results. Ref. J Nocedal, "Numerical optimization", page 615
684+
for idx in range(n_T - 1, -1, -1):
685+
current_T_idx = self.T_indices[self.T_indptr[idx]: self.T_indptr[idx+1]]
686+
687+
cum_sum += np.sum(vec[current_T_idx])
688+
out[current_T_idx] = cum_sum
689+
690+
return out
691+
692+
def _B_T_dot_vec(self, vec):
693+
# compute `B.T @ vec` in O(n) instead of O(n^2)
694+
out = np.zeros_like(vec)
695+
n_T = self.T_indptr.shape[0] - 1
696+
cum_sum = 0.
697+
698+
for idx in range(n_T):
699+
current_T_idx = self.T_indices[self.T_indptr[idx]: self.T_indptr[idx+1]]
700+
701+
cum_sum += np.sum(vec[current_T_idx])
702+
out[current_T_idx] = cum_sum
703+
704+
return out
705+
688706
def _A_dot_vec(self, vec):
707+
# compute `A @ vec` in O(n) instead of O(n^2)
689708
out = np.zeros_like(vec)
690709
n_H = self.H_indptr.shape[0] - 1
691710

@@ -700,6 +719,7 @@ def _A_dot_vec(self, vec):
700719
return out
701720

702721
def _AT_dot_vec(self, vec):
722+
# compute `A.T @ vec` in O(n) instead of O(n^2)
703723
out = np.zeros_like(vec)
704724
n_H = self.H_indptr.shape[0] - 1
705725

@@ -712,3 +732,21 @@ def _AT_dot_vec(self, vec):
712732
out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H)
713733

714734
return out
735+
736+
def _get_indptr(self, vals, indices):
737+
# given `indices = argsort(vals)`
738+
# build and array `indptr` where two consecutive elements
739+
# delimit indices with the same val
740+
n_indices = indices.shape[0]
741+
742+
indptr = [0]
743+
count = 1
744+
for i in range(n_indices - 1):
745+
if vals[indices[i]] == vals[indices[i+1]]:
746+
count += 1
747+
else:
748+
indptr.append(count + indptr[-1])
749+
count = 1
750+
indptr.append(n_indices)
751+
752+
return np.asarray(indptr, dtype=np.int64)

0 commit comments

Comments
 (0)