import numpy as np
from numpy.linalg import norm
from numba import njit
- from numba import float64
+ from numba import float64, int64, bool_

from skglm.datafits.base import BaseDatafit
from skglm.utils.sparse_ops import spectral_norm
@@ -547,90 +547,100 @@ def intercept_update_self(self, y, Xw):


class Cox(BaseDatafit):
-     r"""Cox datafit for survival analysis with Breslow estimate.
+     r"""Cox datafit for survival analysis.

-     The datafit reads [1]
-
-     .. math::
-
-         1 / n_\text{samples} \sum_{i=1}^{n_\text{samples}} -s_i \langle x_i, w \rangle
-         + \log (\sum_{j | y_j \geq y_i} e^{\langle x_j, w \rangle})
-
-     where :math:`s_i` indicates the sample censorship and :math:`tm`
-     is the vector recording the time of event occurrences.
-
-     Defining the matrix :math:`B` with
-     :math:`B_{i,j} = 1` if :math:`tm_j \geq tm_i` and :math:`0` otherwise,
-     the datafit can be rewritten in the following compact form
-
-     .. math::
-
-         - 1 / n_\text{samples} \langle s, Xw \rangle
-         + 1 / n_\text{samples} \langle s, \log B e^{Xw} \rangle
+     Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>` for details.

+     Parameters
+     ----------
+     use_efron : bool, default=False
+         If ``True``, use the Efron estimate to handle tied observations.

    Attributes
    ----------
    B : array-like, shape (n_samples, n_samples)
        Matrix where every ``(i, j)`` entry (row, column) equals ``1``
-         if ``tm[j] >= tm[i]`` and `0` otherwise. This matrix is initialized
+         if ``tm[j] >= tm[i]`` and ``0`` otherwise. This matrix is initialized
        using the ``.initialize`` method.

-     References
-     ----------
-     .. [1] DY Lin. On the Breslow estimator.
-            Lifetime data analysis, 13:471–480, 2007.
+     H_indices : array-like, shape (n_samples,)
+         Indices of observations with the same occurrence times, stacked horizontally
+         as ``[group_1, group_2, ...]``. This array is initialized
+         by the ``.initialize`` method when ``use_efron=True``.
+
+     H_indptr : array-like, shape (np.unique(tm).shape[0] + 1,)
+         Array where two consecutive elements delimit a group of observations
+         having the same occurrence times.

    """

-     def __init__(self):
-         pass
+     def __init__(self, use_efron=False):
+         self.use_efron = use_efron

    def get_spec(self):
        return (
+             ('use_efron', bool_),
            ('B', float64[:, ::1]),
+             ('H_indptr', int64[:]),
+             ('H_indices', int64[:]),
        )

    def params_to_dict(self):
-         return dict()
+         return dict(use_efron=self.use_efron)

    def value(self, y, w, Xw):
        """Compute the value of the datafit."""
        tm, s = y
        n_samples = Xw.shape[0]

-         out = -(s @ Xw) + s @ np.log(self.B @ np.exp(Xw))
+         # compute inside log term
+         exp_Xw = np.exp(Xw)
+         B_exp_Xw = self.B @ exp_Xw
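+         # With Efron tie handling, each tied event's risk set drops a fraction of the
+         # tied samples' total risk; ``_A_dot_vec`` (defined below) computes that
+         # correction, hence the subtraction that follows.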
+         if self.use_efron:
+             B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+         out = -(s @ Xw) + s @ np.log(B_exp_Xw)
        return out / n_samples

    def raw_grad(self, y, Xw):
        r"""Compute gradient of datafit w.r.t. ``Xw``.

-         The raw gradient reads
-
-             (-s + exp_Xw * (B.T @ (s / (B @ exp_Xw)))) / n_samples
+         Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
+         equation 4 for details.
        """
        tm, s = y
        n_samples = Xw.shape[0]

        exp_Xw = np.exp(Xw)
        B_exp_Xw = self.B @ exp_Xw
+         if self.use_efron:
+             B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+         s_over_B_exp_Xw = s / B_exp_Xw
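+         # Same pattern as ``value``: with Efron, B is effectively replaced by (B - A),
+         # so both the risk sums above and the transposed products below get an
+         # ``_A_dot_vec`` / ``_AT_dot_vec`` correction.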
+         out = -s + exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
+         if self.use_efron:
+             out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

-         out = -s + exp_Xw * (self.B.T @ (s / B_exp_Xw))
        return out / n_samples

    def raw_hessian(self, y, Xw):
        """Compute a diagonal upper bound of the datafit's Hessian w.r.t. ``Xw``.

-         The diagonal upper bound reads
-
-             exp_Xw * (B.T @ (s / B_exp_Xw)) / n_samples
+         Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
+         equation 6 for details.
        """
        tm, s = y
        n_samples = Xw.shape[0]

        exp_Xw = np.exp(Xw)
        B_exp_Xw = self.B @ exp_Xw
+         if self.use_efron:
+             B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+         s_over_B_exp_Xw = s / B_exp_Xw
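+         # Only the nonnegative part of the exact Hessian diagonal is kept here (the
+         # exact diagonal also subtracts a nonnegative quadratic term), which is what
+         # makes ``out`` a diagonal upper bound as stated in the docstring.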
+         out = exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
+         if self.use_efron:
+             out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

-         out = exp_Xw * (self.B.T @ (s / B_exp_Xw))
        return out / n_samples

    def initialize(self, X, y):
@@ -640,9 +650,58 @@ def initialize(self, X, y):
        tm_as_col = tm.reshape((-1, 1))
        self.B = (tm >= tm_as_col).astype(X.dtype)

+         if self.use_efron:
+             H_indices = np.argsort(tm)
+             # filter out censored data
+             H_indices = H_indices[s[H_indices] != 0]
+             n_uncensored_samples = H_indices.shape[0]
+
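+             # H_indices / H_indptr follow a CSR-like layout: H_indices stacks the
+             # uncensored sample indices sorted by occurrence time, and
+             # H_indices[H_indptr[k]:H_indptr[k + 1]] is the k-th group of tied times.
+             # Toy example (not from this PR): tm = [4., 4., 7.], all uncensored, gives
+             # H_indices = [0, 1, 2] and H_indptr = [0, 2, 3].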
+             # build H_indptr
+             H_indptr = [0]
+             count = 1
+             for i in range(1, n_uncensored_samples):
+                 if tm[H_indices[i - 1]] == tm[H_indices[i]]:
+                     count += 1
+                 else:
+                     H_indptr.append(count + H_indptr[-1])
+                     count = 1
+             H_indptr.append(n_uncensored_samples)
+             H_indptr = np.asarray(H_indptr, dtype=np.int64)
+
+             # save in instance
+             self.H_indptr = H_indptr
+             self.H_indices = H_indices
+

    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
        """Initialize the datafit attributes in sparse dataset case."""
-         tm, s = y
+         # initialize_sparse and initialize have the same implementation
+         # small hack to avoid repetitive code: pass in X_data as only its dtype is used
+         self.initialize(X_data, y)

-         tm_as_col = tm.reshape((-1, 1))
-         self.B = (tm >= tm_as_col).astype(X_data.dtype)
+     def _A_dot_vec(self, vec):
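+         # Matrix-free product with the (implicit) Efron matrix A: within each group of
+         # tied occurrence times, the group's entries of ``out`` receive the group total
+         # ``sum(vec[group])`` scaled by [0/m, 1/m, ..., (m - 1)/m] (m = group size);
+         # entries outside any tie group stay 0.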
+         out = np.zeros_like(vec)
+         n_H = self.H_indptr.shape[0] - 1
+
+         for idx in range(n_H):
+             current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx + 1]]
+             size_current_H = current_H_idx.shape[0]
+             frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H
+
+             sum_vec_H = np.sum(vec[current_H_idx])
+             out[current_H_idx] = sum_vec_H * frac_range
+
+         return out
+
+     def _AT_dot_vec(self, vec):
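+         # Transpose counterpart of ``_A_dot_vec``: within each tie group, every entry
+         # of ``out`` receives the same weighted sum ``vec[group] @ [0/m, ..., (m - 1)/m]``;
+         # entries outside any tie group stay 0.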
+         out = np.zeros_like(vec)
+         n_H = self.H_indptr.shape[0] - 1
+
+         for idx in range(n_H):
+             current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx + 1]]
+             size_current_H = current_H_idx.shape[0]
+             frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H
+
+             weighted_sum_vec_H = vec[current_H_idx] @ frac_range
+             out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H)
+
+         return out