1818from .base import param_bounds , tf_clip_param , np_clip_param , apply_constraints
1919
2020from .external import AbstractEstimator , XArrayEstimatorStore , InputData , Model , MonitoredTFEstimator , TFEstimatorGraph
21- from .external import nb_utils , train_utils , op_utils , rand_utils , data_utils
21+ from .external import nb_utils , train_utils , op_utils , rand_utils , data_utils , nb_glm_utils
2222from .external import pkg_constants
2323from .hessians import Hessians
2424from .jacobians import Jacobians
@@ -759,55 +759,23 @@ def __init__(
759759
760760 if init_a .lower () == "closed_form" :
761761 try :
762- unique_design_loc , inverse_idx = np .unique (input_data .design_loc , axis = 0 , return_inverse = True )
763- if input_data .constraints_loc is not None :
764- unique_design_loc_constraints = input_data .constraints_loc .copy ()
765- # -1 in the constraint matrix is used to indicate which variable
766- # is made dependent so that the constrained is fullfilled.
767- # This has to be rewritten here so that the design matrix is full rank
768- # which is necessary so that it can be inverted for parameter
769- # initialisation.
770- unique_design_loc_constraints [unique_design_loc_constraints == - 1 ] = 1
771- # Add constraints into design matrix to remove structural unidentifiability.
772- unique_design_loc = np .vstack ([unique_design_loc , unique_design_loc_constraints ])
773-
774- if unique_design_loc .shape [1 ] > np .linalg .matrix_rank (unique_design_loc ):
775- logger .warning ("Location model is not full rank!" )
776- X = input_data .X .assign_coords (group = (("observations" ,), inverse_idx ))
777- if size_factors_init is not None :
778- X = np .divide (X , size_factors_init )
779-
780- groupwise_means = X .groupby ("group" ).mean (dim = "observations" ).values
781- # clipping
782- groupwise_means = np_clip_param (groupwise_means , "mu" )
783- # mean = np.nextafter(0, 1, out=mean.values, where=mean == 0, dtype=mean.dtype)
784-
785- a = np .log (groupwise_means )
786- if input_data .constraints_loc is not None :
787- a_constraints = np .zeros ([input_data .constraints_loc .shape [0 ], a .shape [1 ]])
788- # Add constraints (sum to zero) to value vector to remove structural unidentifiability.
789- a = np .vstack ([a , a_constraints ])
790-
791- # inv_design = np.linalg.pinv(unique_design_loc) # NOTE: this is numerically inaccurate!
792- # inv_design = np.linalg.inv(unique_design_loc) # NOTE: this is exact if full rank!
793- # init_a = np.matmul(inv_design, a)
794- #
795- # Use least-squares solver to calculate a':
796- # This is faster and more accurate than using matrix inversion.
797- logger .debug (" ** Solve lstsq problem" )
798- a_prime = np .linalg .lstsq (unique_design_loc , a , rcond = None )
799- init_a = a_prime [0 ]
800- # stat_utils.rmsd(np.exp(unique_design_loc @ init_a), mean)
762+ groupwise_means , init_a , rmsd_a = nb_glm_utils .closedform_nb_glm_logmu (
763+ X = input_data .X ,
764+ design_loc = input_data .design_loc ,
765+ constraints = input_data .constraints_loc ,
766+ size_factors = size_factors_init ,
767+ link_fn = lambda mu : np .log (np_clip_param (mu , "mu" ))
768+ )
801769
802770 # train mu, if the closed-form solution is inaccurate
803- self ._train_mu = not np .all (a_prime [ 1 ] == 0 )
771+ self ._train_mu = not np .all (rmsd_a == 0 )
804772
805773 # Temporal fix: train mu if size factors are given as closed form may be different:
806774 if input_data .size_factors is not None :
807775 self ._train_mu = True
808776
809777 logger .info ("Using closed-form MLE initialization for mean" )
810- logger .debug ("RMSE of closed-form mean:\n %s" , a_prime [ 1 ] )
778+ logger .debug ("RMSE of closed-form mean:\n %s" , rmsd_a )
811779 logger .info ("Should train mu: %s" , self ._train_mu )
812780 except np .linalg .LinAlgError :
813781 logger .warning ("Closed form initialization failed!" )
@@ -831,63 +799,22 @@ def __init__(
831799
832800 if init_b .lower () == "closed_form" :
833801 try :
834- unique_design_scale , inverse_idx = np .unique (input_data .design_scale , axis = 0 ,
835- return_inverse = True )
836- if input_data .constraints_scale is not None :
837- unique_design_scale_constraints = input_data .constraints_scale .copy ()
838- # -1 in the constraint matrix is used to indicate which variable
839- # is made dependent so that the constrained is fullfilled.
840- # This has to be rewritten here so that the design matrix is full rank
841- # which is necessary so that it can be inverted for parameter
842- # initialisation.
843- unique_design_scale_constraints [unique_design_scale_constraints == - 1 ] = 1
844- # Add constraints into design matrix to remove structural unidentifiability.
845- unique_design_scale = np .vstack ([unique_design_scale , unique_design_scale_constraints ])
846-
847- if unique_design_scale .shape [1 ] > np .linalg .matrix_rank (unique_design_scale ):
848- logger .warning ("Scale model is not full rank!" )
849-
850- X = input_data .X .assign_coords (group = (("observations" ,), inverse_idx ))
851- if input_data .size_factors is not None :
852- X = np .divide (X , size_factors_init )
853-
854- # Xdiff = X - np.exp(input_data.design_loc @ init_a)
855- # Define xarray version of init so that Xdiff can be evaluated lazy by dask.
856802 init_a_xr = data_utils .xarray_from_data (init_a , dims = ("design_loc_params" , "features" ))
857803 init_a_xr .coords ["design_loc_params" ] = input_data .design_loc .coords ["design_loc_params" ]
858- logger .debug (" ** Define Xdiff" )
859- Xdiff = X - np .exp (input_data .design_loc .dot (init_a_xr ))
860- variance = np .square (Xdiff ).groupby ("group" ).mean (dim = "observations" )
861-
862- if groupwise_means is None :
863- groupwise_means = X .groupby ("group" ).mean (dim = "observations" )
864- denominator = np .fmax (variance - groupwise_means , 0 )
865- denominator = np .nextafter (0 , 1 , out = denominator .values , where = denominator == 0 ,
866- dtype = denominator .dtype )
867- r = np .asarray (np .square (groupwise_means ) / denominator )
868- # clipping
869- r = np_clip_param (r , "r" )
870- # r = np.nextafter(0, 1, out=r.values, where=r == 0, dtype=r.dtype)
871- # r = np.fmin(r, np.finfo(r.dtype).max)
872-
873- b = np .log (r )
874- if input_data .constraints_scale is not None :
875- b_constraints = np .zeros ([input_data .constraints_scale .shape [0 ], b .shape [1 ]])
876- # Add constraints (sum to zero) to value vector to remove structural unidentifiability.
877- b = np .vstack ([b , b_constraints ])
878-
879- # inv_design = np.linalg.pinv(unique_design_scale) # NOTE: this is numerically inaccurate!
880- # inv_design = np.linalg.inv(unique_design_scale) # NOTE: this is exact if full rank!
881- # init_b = np.matmul(inv_design, b)
882- #
883- # Use least-squares solver to calculate a':
884- # This is faster and more accurate than using matrix inversion.
885- logger .debug (" ** Solve lstsq problem" )
886- b_prime = np .linalg .lstsq (unique_design_scale , b , rcond = None )
887- init_b = b_prime [0 ]
804+ init_mu = np .exp (input_data .design_loc .dot (init_a_xr ))
805+
806+ groupwise_scales , init_b , rmsd_b = nb_glm_utils .closedform_nb_glm_logphi (
807+ X = input_data .X ,
808+ mu = init_mu ,
809+ design_scale = input_data .design_scale ,
810+ constraints = input_data .constraints_scale ,
811+ size_factors = size_factors_init ,
812+ groupwise_means = groupwise_means ,
813+ link_fn = lambda r : np .log (np_clip_param (r , "r" ))
814+ )
888815
889816 logger .info ("Using closed-form MME initialization for dispersion" )
890- logger .debug ("RMSE of closed-form dispersion:\n %s" , b_prime [ 1 ] )
817+ logger .debug ("RMSE of closed-form dispersion:\n %s" , rmsd_b )
891818 logger .info ("Should train r: %s" , self ._train_r )
892819 except np .linalg .LinAlgError :
893820 logger .warning ("Closed form initialization failed!" )
@@ -903,8 +830,11 @@ def __init__(
903830 my_loc_names = set (input_data .design_loc_names .values )
904831 my_loc_names = my_loc_names .intersection (init_model .input_data .design_loc_names .values )
905832
906- # Initialize new parameters to zero:
907- init_loc = np .zeros (shape = (input_data .num_design_loc_params , input_data .num_features ))
833+ init_loc = np .random .uniform (
834+ low = np .nextafter (0 , 1 , dtype = input_data .X .dtype ),
835+ high = np .sqrt (np .nextafter (0 , 1 , dtype = input_data .X .dtype )),
836+ size = (input_data .num_design_loc_params , input_data .num_features )
837+ )
908838 for parm in my_loc_names :
909839 init_idx = np .where (init_model .input_data .design_loc_names == parm )
910840 my_idx = np .where (input_data .design_loc_names == parm )
@@ -917,8 +847,11 @@ def __init__(
917847 my_scale_names = set (input_data .design_scale_names .values )
918848 my_scale_names = my_scale_names .intersection (init_model .input_data .design_scale_names .values )
919849
920- # Initialize new parameters to zero:
921- init_scale = np .zeros (shape = (input_data .num_design_scale_params , input_data .num_features ))
850+ init_scale = np .random .uniform (
851+ low = np .nextafter (0 , 1 , dtype = input_data .X .dtype ),
852+ high = np .sqrt (np .nextafter (0 , 1 , dtype = input_data .X .dtype )),
853+ size = (input_data .num_design_scale_params , input_data .num_features )
854+ )
922855 for parm in my_scale_names :
923856 init_idx = np .where (init_model .input_data .design_scale_names == parm )
924857 my_idx = np .where (input_data .design_scale_names == parm )
0 commit comments