@@ -187,7 +187,7 @@ def init_par(
187187 )
188188 else :
189189 init_a_xr = data_utils .xarray_from_data (init_a , dims = ("loc_params" , "features" ))
190- init_a_xr .coords ["loc_params" ] = input_data .constraints_loc .coords ["loc_params" ]
190+ init_a_xr .coords ["loc_params" ] = input_data .constraints_loc .coords ["loc_params" ]. values
191191 init_mu = input_data .design_loc .dot (input_data .constraints_loc .dot (init_a_xr ))
192192
193193 if size_factors_init is not None :
@@ -236,28 +236,28 @@ def init_par(
236236 else :
237237 # Locations model:
238238 if isinstance (init_a , str ) and (init_a .lower () == "auto" or init_a .lower () == "init_model" ):
239- my_loc_names = set (input_data .design_loc_names .values )
240- my_loc_names = my_loc_names .intersection (init_model .input_data .design_loc_names .values )
239+ my_loc_names = set (input_data .loc_names .values )
240+ my_loc_names = my_loc_names .intersection (set ( init_model .input_data .loc_names .values ) )
241241
242242 init_loc = np .zeros ([input_data .num_loc_params , input_data .num_features ])
243243 for parm in my_loc_names :
244- init_idx = np .where (init_model .input_data .design_loc_names == parm )
245- my_idx = np .where (input_data .design_loc_names == parm )
246- init_loc [my_idx ] = init_model .par_link_loc [init_idx ]
244+ init_idx = np .where (init_model .input_data .loc_names == parm )[ 0 ]
245+ my_idx = np .where (input_data .loc_names == parm )[ 0 ]
246+ init_loc [my_idx ] = init_model .a_var [init_idx ]
247247
248248 init_a = init_loc
249249 logger .debug ("Using initialization based on input model for mean" )
250250
251251 # Scale model:
252252 if isinstance (init_b , str ) and (init_b .lower () == "auto" or init_b .lower () == "init_model" ):
253- my_scale_names = set (input_data .design_scale_names .values )
254- my_scale_names = my_scale_names .intersection (init_model .input_data .design_scale_names .values )
253+ my_scale_names = set (input_data .scale_names .values )
254+ my_scale_names = my_scale_names .intersection (init_model .input_data .scale_names .values )
255255
256256 init_scale = np .zeros ([input_data .num_scale_params , input_data .num_features ])
257257 for parm in my_scale_names :
258- init_idx = np .where (init_model .input_data .design_scale_names == parm )
259- my_idx = np .where (input_data .design_scale_names == parm )
260- init_scale [my_idx ] = init_model .par_link_scale [init_idx ]
258+ init_idx = np .where (init_model .input_data .scale_names == parm )[ 0 ]
259+ my_idx = np .where (input_data .scale_names == parm )[ 0 ]
260+ init_scale [my_idx ] = init_model .b_var [init_idx ]
261261
262262 init_b = init_scale
263263 logger .debug ("Using initialization based on input model for dispersion" )
0 commit comments