 from .base import param_bounds, tf_clip_param, np_clip_param, apply_constraints
 
 from .external import AbstractEstimator, XArrayEstimatorStore, InputData, Model, MonitoredTFEstimator, TFEstimatorGraph
-from .external import nb_utils, train_utils, op_utils, rand_utils
+from .external import nb_utils, train_utils, op_utils, rand_utils, data_utils
 from .external import pkg_constants
 from .hessians import Hessians
 from .jacobians import Jacobians
@@ -772,6 +772,9 @@ def __init__(
             shape=[input_data.num_observations, input_data.num_features]
         )
 
+        groupwise_means = None  # [groups, features]
+        overall_means = None  # [1, features]
+        logger.debug(" * Initialize mean model")
         if isinstance(init_a, str):
             # Choose option if auto was chosen
             if init_a.lower() == "auto":
@@ -797,12 +800,12 @@ def __init__(
                     if size_factors_init is not None:
                         X = np.divide(X, size_factors_init)
 
-                    mean = X.groupby("group").mean(dim="observations").values
+                    groupwise_means = X.groupby("group").mean(dim="observations").values
                     # clipping
-                    mean = np_clip_param(mean, "mu")
+                    groupwise_means = np_clip_param(groupwise_means, "mu")
                     # mean = np.nextafter(0, 1, out=mean.values, where=mean == 0, dtype=mean.dtype)
 
-                    a = np.log(mean)
+                    a = np.log(groupwise_means)
                     if input_data.constraints_loc is not None:
                         a_constraints = np.zeros([input_data.constraints_loc.shape[0], a.shape[1]])
                         # Add constraints (sum to zero) to value vector to remove structural unidentifiability.
@@ -812,7 +815,9 @@ def __init__(
                     # inv_design = np.linalg.inv(unique_design_loc) # NOTE: this is exact if full rank!
                     # init_a = np.matmul(inv_design, a)
                     #
-                    # Better option: use least-squares solver to calculate a'
+                    # Use least-squares solver to calculate a':
+                    # This is faster and more accurate than using matrix inversion.
+                    logger.debug(" ** Solve lstsq problem")
                     a_prime = np.linalg.lstsq(unique_design_loc, a, rcond=None)
                     init_a = a_prime[0]
                     # stat_utils.rmsd(np.exp(unique_design_loc @ init_a), mean)
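
Note on the hunk above: `np.linalg.lstsq` solves the least-squares problem directly and, unlike explicit inversion, also handles non-square or rank-deficient design matrices, where `np.linalg.inv` is not even defined. A minimal standalone sketch with invented numbers:

    import numpy as np

    # Hypothetical 3-groups-by-2-parameters design matrix: non-square,
    # so it cannot be inverted, but lstsq still solves the system.
    unique_design_loc = np.array([[1., 0.],
                                  [1., 1.],
                                  [1., 0.]])
    a = np.log(np.array([[2.0, 5.0],
                         [3.0, 4.0],
                         [2.5, 6.0]]))  # [groups, features]

    # Minimizes ||unique_design_loc @ init_a - a||^2 column by column.
    init_a, residuals, rank, singular_values = np.linalg.lstsq(
        unique_design_loc, a, rcond=None
    )
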
@@ -830,18 +835,19 @@ def __init__(
                 except np.linalg.LinAlgError:
                     logger.warning("Closed form initialization failed!")
             elif init_a.lower() == "standard":
-                mean = input_data.X.mean(dim="observations").values  # directly calculate the mean
+                overall_means = input_data.X.mean(dim="observations").values  # directly calculate the mean
                 # clipping
-                mean = np_clip_param(mean, "mu")
+                overall_means = np_clip_param(overall_means, "mu")
                 # mean = np.nextafter(0, 1, out=mean, where=mean == 0, dtype=mean.dtype)
 
                 init_a = np.zeros([input_data.num_design_loc_params, input_data.num_features])
-                init_a[0, :] = np.log(mean)
+                init_a[0, :] = np.log(overall_means)
                 self._train_mu = True
 
                 logger.info("Using standard initialization for mean")
                 logger.info("Should train mu: %s", self._train_mu)
 
+        logger.debug(" * Initialize dispersion model")
         if isinstance(init_b, str):
             if init_b.lower() == "auto":
                 init_b = "closed_form"
@@ -868,14 +874,20 @@ def __init__(
                     if input_data.size_factors is not None:
                         X = np.divide(X, size_factors_init)
 
-                    Xdiff = X - np.exp(input_data.design_loc @ init_a)
+                    # Xdiff = X - np.exp(input_data.design_loc @ init_a)
+                    # Define an xarray version of init_a so that Xdiff can be evaluated lazily by dask.
+                    init_a_xr = data_utils.xarray_from_data(init_a, dims=("design_loc_params", "features"))
+                    init_a_xr.coords["design_loc_params"] = input_data.design_loc.coords["design_loc_params"]
+                    logger.debug(" ** Define Xdiff")
+                    Xdiff = X - np.exp(input_data.design_loc.dot(init_a_xr))
                     variance = np.square(Xdiff).groupby("group").mean(dim="observations")
 
-                    group_mean = X.groupby("group").mean(dim="observations")
-                    denominator = np.fmax(variance - group_mean, 0)
+                    if groupwise_means is None:
+                        groupwise_means = X.groupby("group").mean(dim="observations")
+                    denominator = np.fmax(variance - groupwise_means, 0)
                     denominator = np.nextafter(0, 1, out=denominator.values, where=denominator == 0,
                                                dtype=denominator.dtype)
-                    r = np.asarray(np.square(group_mean) / denominator)
+                    r = np.asarray(np.square(groupwise_means) / denominator)
                     # clipping
                     r = np_clip_param(r, "r")
                     # r = np.nextafter(0, 1, out=r.values, where=r == 0, dtype=r.dtype)
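
Note on the dispersion estimate above: it is the method-of-moments estimator implied by the negative binomial relation Var(X) = mu + mu^2 / r, i.e. r = mu^2 / (Var(X) - mu). A standalone numpy sketch of the same computation, with toy per-group moments (all numbers invented):

    import numpy as np

    # Toy per-group moments for two features.
    group_means = np.array([[4.0, 10.0]])
    group_vars = np.array([[6.0, 30.0]])

    denominator = np.fmax(group_vars - group_means, 0)
    # Guard against zero (Poisson-like) denominators, as the diff does.
    denominator = np.nextafter(0, 1, out=denominator, where=denominator == 0,
                               dtype=denominator.dtype)
    r = np.square(group_means) / denominator  # r = mu^2 / (var - mu)
    # r -> [[8., 5.]]
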
@@ -891,7 +903,9 @@ def __init__(
                     # inv_design = np.linalg.inv(unique_design_scale) # NOTE: this is exact if full rank!
                     # init_b = np.matmul(inv_design, b)
                     #
-                    # Better option: use least-squares solver to calculate b''
+                    # Use least-squares solver to calculate b':
+                    # This is faster and more accurate than using matrix inversion.
+                    logger.debug(" ** Solve lstsq problem")
                     b_prime = np.linalg.lstsq(unique_design_scale, b, rcond=None)
                     init_b = b_prime[0]
 
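
Note on the `Xdiff` change two hunks up: replacing the eager `@` matmul with `xarray`'s `.dot` lets dask defer the computation until the result is actually needed. A minimal sketch of that pattern, assuming dask is installed and using a plain `xr.DataArray` in place of the internal `data_utils.xarray_from_data` helper:

    import numpy as np
    import xarray as xr

    design_loc = xr.DataArray(
        np.random.rand(6, 2),
        dims=("observations", "design_loc_params"),
    ).chunk({"observations": 3})  # dask-backed, so downstream ops stay lazy
    init_a = xr.DataArray(
        np.random.rand(2, 3),
        dims=("design_loc_params", "features"),
    )

    # .dot contracts over the shared "design_loc_params" dimension and,
    # unlike "@", builds a dask task graph instead of computing immediately.
    mu = np.exp(design_loc.dot(init_a))  # dims: ("observations", "features")
    result = mu.compute()  # evaluation happens only here
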
@@ -1012,6 +1026,7 @@ def fetch_fn(idx):
         else:
             init_b = init_b.astype(dtype)
 
+        logger.debug(" * Start creating model")
         with graph.as_default():
             # create model
             model = EstimatorGraph(
@@ -1030,6 +1045,7 @@ def fetch_fn(idx):
                 extended_summary=extended_summary,
                 dtype=dtype
             )
+        logger.debug(" * Finished creating model")
 
         MonitoredTFEstimator.__init__(self, model)
 