@@ -106,134 +106,134 @@ def init_par(
106106 shape = [input_data .num_observations , input_data .num_features ]
107107 )
108108
109- groupwise_means = None
110- init_a_str = None
111- if isinstance (init_a , str ):
112- init_a_str = init_a .lower ()
113- # Chose option if auto was chosen
114- if init_a .lower () == "auto" :
115- init_a = "closed_form"
116-
117- if init_a .lower () == "closed_form" :
118- #try:
119- groupwise_means , init_a , rmsd_a = closedform_nb_glm_logmu (
120- X = input_data .X ,
121- design_loc = input_data .design_loc ,
122- constraints_loc = input_data .constraints_loc .values ,
123- size_factors = size_factors_init ,
124- link_fn = lambda mu : np .log (self .np_clip_param (mu , "mu" ))
125- )
126-
127- # train mu, if the closed-form solution is inaccurate
128- self ._train_loc = not np .all (rmsd_a == 0 )
129-
130- if input_data .size_factors is not None :
131- if np .any (input_data .size_factors != 1 ):
132- self ._train_loc = True
133-
134- logger .debug ("Using closed-form MLE initialization for mean" )
135- logger .debug ("Should train mu: %s" , self ._train_loc )
136- #except np.linalg.LinAlgError:
137- # logger.warning("Closed form initialization failed!")
138- elif init_a .lower () == "standard" :
139- if isinstance (input_data .X , SparseXArrayDataArray ):
140- overall_means = input_data .X .mean (dim = "observations" )
141- else :
142- overall_means = input_data .X .mean (dim = "observations" ).values # directly calculate the mean
143- overall_means = self .np_clip_param (overall_means , "mu" )
144-
145- init_a = np .zeros ([input_data .num_loc_params , input_data .num_features ])
146- init_a [0 , :] = np .log (overall_means )
147- self ._train_loc = True
148-
149- logger .debug ("Using standard initialization for mean" )
150- logger .debug ("Should train mu: %s" , self ._train_loc )
151- elif init_a .lower () == "all_zero" :
152- init_a = np .zeros ([input_data .num_loc_params , input_data .num_features ])
153- self ._train_loc = True
154-
155- logger .debug ("Using all_zero initialization for mean" )
156- logger .debug ("Should train mu: %s" , self ._train_loc )
157- else :
158- raise ValueError ("init_a string %s not recognized" % init_a )
159-
160- if isinstance (init_b , str ):
161- if init_b .lower () == "auto" :
162- init_b = "standard"
163-
164- if init_b .lower () == "closed_form" or init_b .lower () == "standard" :
165- #try:
166- # Check whether it is necessary to recompute group-wise means.
167- dmats_unequal = False
168- if input_data .design_loc .shape [1 ] == input_data .design_scale .shape [1 ]:
169- if np .any (input_data .design_loc .values != input_data .design_scale .values ):
170- dmats_unequal = True
171-
172- inits_unequal = False
173- if init_a_str is not None :
174- if init_a_str != init_b :
175- inits_unequal = True
176-
177- if inits_unequal or dmats_unequal :
178- groupwise_means = None
179-
180- # Watch out: init_mu is full obs x features matrix and is very large in many cases.
181- if inits_unequal or dmats_unequal :
109+ if init_model is None :
110+ groupwise_means = None
111+ init_a_str = None
112+ if isinstance (init_a , str ):
113+ init_a_str = init_a .lower ()
114+ # Choose the concrete option if "auto" was requested
115+ if init_a .lower () == "auto" :
116+ init_a = "closed_form"
117+
118+ if init_a .lower () == "closed_form" :
119+ #try:
120+ groupwise_means , init_a , rmsd_a = closedform_nb_glm_logmu (
121+ X = input_data .X ,
122+ design_loc = input_data .design_loc ,
123+ constraints_loc = input_data .constraints_loc .values ,
124+ size_factors = size_factors_init ,
125+ link_fn = lambda mu : np .log (self .np_clip_param (mu , "mu" ))
126+ )
127+
128+ # train mu, if the closed-form solution is inaccurate
129+ self ._train_loc = not np .all (rmsd_a == 0 )
130+
131+ if input_data .size_factors is not None :
132+ if np .any (input_data .size_factors != 1 ):
133+ self ._train_loc = True
134+
135+ logger .debug ("Using closed-form MLE initialization for mean" )
136+ logger .debug ("Should train mu: %s" , self ._train_loc )
137+ #except np.linalg.LinAlgError:
138+ # logger.warning("Closed form initialization failed!")
139+ elif init_a .lower () == "standard" :
182140 if isinstance (input_data .X , SparseXArrayDataArray ):
183- init_mu = np .matmul (
184- input_data .design_loc .values ,
185- np .matmul (input_data .constraints_loc .values , init_a )
186- )
141+ overall_means = input_data .X .mean (dim = "observations" )
187142 else :
188- init_a_xr = data_utils .xarray_from_data (init_a , dims = ("loc_params" , "features" ))
189- init_a_xr .coords ["loc_params" ] = input_data .constraints_loc .coords ["loc_params" ]
190- init_mu = input_data .design_loc .dot (input_data .constraints_loc .dot (init_a_xr ))
143+ overall_means = input_data .X .mean (dim = "observations" ).values # directly calculate the mean
144+ overall_means = self .np_clip_param (overall_means , "mu" )
191145
192- if size_factors_init is not None :
193- init_mu = init_mu + np .log (size_factors_init )
194- init_mu = np .exp (init_mu )
195- else :
196- init_mu = None
146+ init_a = np .zeros ([input_data .num_loc_params , input_data .num_features ])
147+ init_a [0 , :] = np .log (overall_means )
148+ self ._train_loc = True
197149
198- if init_b .lower () == "closed_form" :
199- groupwise_scales , init_b , rmsd_b = closedform_nb_glm_logphi (
200- X = input_data .X ,
201- mu = init_mu ,
202- design_scale = input_data .design_scale ,
203- constraints = input_data .constraints_scale .values ,
204- size_factors = size_factors_init ,
205- groupwise_means = groupwise_means ,
206- link_fn = lambda r : np .log (self .np_clip_param (r , "r" ))
207- )
150+ logger .debug ("Using standard initialization for mean" )
151+ logger .debug ("Should train mu: %s" , self ._train_loc )
152+ elif init_a .lower () == "all_zero" :
153+ init_a = np .zeros ([input_data .num_loc_params , input_data .num_features ])
154+ self ._train_loc = True
208155
209- logger .debug ("Using closed-form MME initialization for dispersion" )
210- logger .debug ("Should train r: %s" , self ._train_scale )
211- elif init_b .lower () == "standard" :
212- groupwise_scales , init_b_intercept , rmsd_b = closedform_nb_glm_logphi (
213- X = input_data .X ,
214- mu = init_mu ,
215- design_scale = input_data .design_scale [:,[0 ]],
216- constraints = input_data .constraints_scale [[0 ], [0 ]].values ,
217- size_factors = size_factors_init ,
218- groupwise_means = None ,
219- link_fn = lambda r : np .log (self .np_clip_param (r , "r" ))
220- )
156+ logger .debug ("Using all_zero initialization for mean" )
157+ logger .debug ("Should train mu: %s" , self ._train_loc )
158+ else :
159+ raise ValueError ("init_a string %s not recognized" % init_a )
160+
161+ if isinstance (init_b , str ):
162+ if init_b .lower () == "auto" :
163+ init_b = "standard"
164+
165+ if init_b .lower () == "closed_form" or init_b .lower () == "standard" :
166+ #try:
167+ # Check whether it is necessary to recompute group-wise means.
168+ dmats_unequal = False
169+ if input_data .design_loc .shape [1 ] == input_data .design_scale .shape [1 ]:
170+ if np .any (input_data .design_loc .values != input_data .design_scale .values ):
171+ dmats_unequal = True
172+
173+ inits_unequal = False
174+ if init_a_str is not None :
175+ if init_a_str != init_b :
176+ inits_unequal = True
177+
178+ if inits_unequal or dmats_unequal :
179+ groupwise_means = None
180+
181+ # Watch out: init_mu is full obs x features matrix and is very large in many cases.
182+ if inits_unequal or dmats_unequal :
183+ if isinstance (input_data .X , SparseXArrayDataArray ):
184+ init_mu = np .matmul (
185+ input_data .design_loc .values ,
186+ np .matmul (input_data .constraints_loc .values , init_a )
187+ )
188+ else :
189+ init_a_xr = data_utils .xarray_from_data (init_a , dims = ("loc_params" , "features" ))
190+ init_a_xr .coords ["loc_params" ] = input_data .constraints_loc .coords ["loc_params" ]
191+ init_mu = input_data .design_loc .dot (input_data .constraints_loc .dot (init_a_xr ))
192+
193+ if size_factors_init is not None :
194+ init_mu = init_mu + np .log (size_factors_init )
195+ init_mu = np .exp (init_mu )
196+ else :
197+ init_mu = None
198+
199+ if init_b .lower () == "closed_form" :
200+ groupwise_scales , init_b , rmsd_b = closedform_nb_glm_logphi (
201+ X = input_data .X ,
202+ mu = init_mu ,
203+ design_scale = input_data .design_scale ,
204+ constraints = input_data .constraints_scale .values ,
205+ size_factors = size_factors_init ,
206+ groupwise_means = groupwise_means ,
207+ link_fn = lambda r : np .log (self .np_clip_param (r , "r" ))
208+ )
209+
210+ logger .debug ("Using closed-form MME initialization for dispersion" )
211+ logger .debug ("Should train r: %s" , self ._train_scale )
212+ elif init_b .lower () == "standard" :
213+ groupwise_scales , init_b_intercept , rmsd_b = closedform_nb_glm_logphi (
214+ X = input_data .X ,
215+ mu = init_mu ,
216+ design_scale = input_data .design_scale [:,[0 ]],
217+ constraints = input_data .constraints_scale [[0 ], [0 ]].values ,
218+ size_factors = size_factors_init ,
219+ groupwise_means = None ,
220+ link_fn = lambda r : np .log (self .np_clip_param (r , "r" ))
221+ )
222+ init_b = np .zeros ([input_data .num_scale_params , input_data .X .shape [1 ]])
223+ init_b [0 , :] = init_b_intercept
224+
225+ logger .debug ("Using closed-form MME initialization for dispersion" )
226+ logger .debug ("Should train r: %s" , self ._train_scale )
227+ #except np.linalg.LinAlgError:
228+ # logger.warning("Closed form initialization failed!")
229+ elif init_b .lower () == "all_zero" :
221230 init_b = np .zeros ([input_data .num_scale_params , input_data .X .shape [1 ]])
222- init_b [0 , :] = init_b_intercept
223231
224- logger .debug ("Using closed-form MME initialization for dispersion" )
232+ logger .debug ("Using standard initialization for dispersion" )
225233 logger .debug ("Should train r: %s" , self ._train_scale )
226- #except np.linalg.LinAlgError:
227- # logger.warning("Closed form initialization failed!")
228- elif init_b .lower () == "all_zero" :
229- init_b = np .zeros ([input_data .num_scale_params , input_data .X .shape [1 ]])
230-
231- logger .debug ("Using standard initialization for dispersion" )
232- logger .debug ("Should train r: %s" , self ._train_scale )
233- else :
234- raise ValueError ("init_b string %s not recognized" % init_b )
235-
236- if init_model is not None :
234+ else :
235+ raise ValueError ("init_b string %s not recognized" % init_b )
236+ else :
237237 # Locations model:
238238 if isinstance (init_a , str ) and (init_a .lower () == "auto" or init_a .lower () == "init_model" ):
239239 my_loc_names = set (input_data .design_loc_names .values )
0 commit comments