@@ -52,7 +52,7 @@ def train(
             update_b_freq: int = 5,
             ftol_b: float = 1e-8,
             lr_b: float = 1e-2,
-            max_iter_b: int = 100,
+            max_iter_b: int = 1000,
             nproc: int = 3,
             **kwargs
     ):
@@ -250,7 +250,7 @@ def iwls_step(
 
         :return: (inferred param x features)
         """
-        w = self.model.fim_weight_j(j=idx_update)  # (observations x features)
+        w = self.model.fim_weight_aa_j(j=idx_update)  # (observations x features)
         ybar = self.model.ybar_j(j=idx_update)  # (observations x features)
         # Translate to problem of form ax = b for each feature:
         # (in the following, X=design and Y=counts)
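
For reference, the comment in this hunk describes reducing the IWLS update to a per-feature linear system ax = b with a = X^T diag(w) X and b = X^T diag(w) ybar. A minimal sketch of that reduction for a single feature, assuming dense numpy inputs (the helper below is illustrative only; the actual method solves all features at once on dask arrays):

    import numpy as np

    def iwls_update(X, w, ybar):
        # Normal equations of the weighted least-squares problem:
        # (X^T diag(w) X) delta = X^T diag(w) ybar
        a = X.T @ (w[:, None] * X)   # (params x params)
        b = X.T @ (w * ybar)         # (params,)
        return np.linalg.solve(a, b)

    # Toy usage: 100 observations, 3 design columns.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 3))
    w = rng.uniform(0.5, 2.0, size=100)   # FIM weights, one per observation
    ybar = rng.normal(size=100)           # working response
    delta = iwls_update(X, w, ybar)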
@@ -368,6 +368,7 @@ def _b_step_gd(
 
     def optim_handle(
             self,
+            b_j,
             data_j,
             eta_loc_j,
             xh_scale,
@@ -381,17 +382,33 @@ def optim_handle(
             data_j = np.expand_dims(data_j, axis=-1)
 
         ll = self.model.ll_handle()
+        lb, ub = self.model.param_bounds(dtype=data_j.dtype)
+        lb_bracket = np.max([lb["b_var"], b_j - 20])
+        ub_bracket = np.min([ub["b_var"], b_j + 20])
 
         def cost_b_var(x, data_jj, eta_loc_jj, xh_scale_jj):
-            x = np.array([[x]])
+            x = np.clip(np.array([[x]]), lb["b_var"], ub["b_var"])
             return -np.sum(ll(data_jj, eta_loc_jj, x, xh_scale_jj))
 
+        # jac_b = self.model.jac_b_handle()
+        # def cost_b_var_prime(x, data_jj, eta_loc_jj, xh_scale_jj):
+        #     x = np.clip(np.array([[x]]), lb["b_var"], ub["b_var"])
+        #     return -np.sum(jac_b(data_jj, eta_loc_jj, x, xh_scale_jj))
+        # return scipy.optimize.line_search(
+        #     f=cost_b_var,
+        #     myfprime=cost_b_var_prime,
+        #     args=(data_j, eta_loc_j, xh_scale),
+        #     maxiter=max_iter,
+        #     xk=b_j+5,
+        #     pk=-np.ones_like(b_j)
+        # )
+
         return scipy.optimize.brent(
             func=cost_b_var,
             args=(data_j, eta_loc_j, xh_scale),
             maxiter=max_iter,
             tol=ftol,
-            brack=(-5, 5),
+            brack=(lb_bracket, ub_bracket),
             full_output=True
         )
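
For reference, `scipy.optimize.brent` treats `brack` as an initial search interval, not a hard constraint, so centering it on the current estimate (`b_j ± 20`, clipped to the model's parameter bounds) starts the scalar minimization near a plausible dispersion instead of the fixed `(-5, 5)`. A self-contained sketch of the call pattern, with a stand-in cost in place of the negative log-likelihood:

    import scipy.optimize

    def cost(x):
        # Stand-in for the per-feature negative log-likelihood.
        return (x - 1.5) ** 2

    xmin, fval, niter, funcalls = scipy.optimize.brent(
        func=cost,
        brack=(-20.0, 20.0),  # initial bracketing interval, like (lb_bracket, ub_bracket)
        maxiter=100,
        tol=1e-8,
        full_output=True
    )
    # xmin is approximately 1.5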
@@ -407,13 +424,13 @@ def _b_step_loop(
 
         :return:
         """
-        x0 = -10
         delta_theta = np.zeros_like(self.model.b_var)
         if isinstance(delta_theta, dask.array.core.Array):
             delta_theta = delta_theta.compute()
 
         xh_scale = np.matmul(self.model.design_scale, self.model.constraints_scale).compute()
-        if nproc > 1:
+        b_var = self.model.b_var.compute()
+        if nproc > 1 and len(idx_update) > nproc:
             sys.stdout.write('\rFitting %i dispersion models: (progress not available with multiprocessing)' % len(idx_update))
             sys.stdout.flush()
             with multiprocessing.Pool(processes=nproc) as pool:
@@ -422,6 +439,7 @@ def _b_step_loop(
                 results = pool.starmap(
                     self.optim_handle,
                     [(
+                        b_var[0, j],
                         x[:, [j]],
                         eta_loc[:, [j]],
                         xh_scale,
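
For reference, the per-feature fits are dispatched with `Pool.starmap`, which unpacks one argument tuple per task; the tuples now carry the current `b_var[0, j]` so each worker can center its bracket, and the `len(idx_update) > nproc` guard above skips process start-up cost when there are no more features than workers. A minimal sketch of the same dispatch pattern (`fit_one` is a hypothetical stand-in for `optim_handle`):

    import multiprocessing

    def fit_one(b0, lo, hi):
        # Stand-in for optim_handle: one scalar fit per feature.
        return min(max(b0, lo), hi)

    if __name__ == "__main__":
        tasks = [(b0, -5.0, 5.0) for b0 in (-8.0, 0.3, 7.1, 2.2)]
        nproc = 3
        if nproc > 1 and len(tasks) > nproc:
            with multiprocessing.Pool(processes=nproc) as pool:
                results = pool.starmap(fit_one, tasks)
        else:
            results = [fit_one(*t) for t in tasks]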
@@ -452,12 +470,16 @@ def _b_step_loop(
                     data = data.todense()
 
                 ll = self.model.ll_handle()
+                lb, ub = self.model.param_bounds(dtype=data.dtype)
+                lb_bracket = np.max([lb["b_var"], b_var[0, j] - 20])
+                ub_bracket = np.min([ub["b_var"], b_var[0, j] + 20])
 
                 def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
+                    x = np.clip(np.array([[x]]), lb["b_var"], ub["b_var"])
                     return -np.sum(ll(
                         data_j,
                         eta_loc_j,
-                        np.array([[x]]),
+                        x,
                         xh_scale_j
                     ))
 
@@ -466,7 +488,7 @@ def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
                         args=(data, eta_loc, xh_scale),
                         maxiter=max_iter,
                         tol=ftol,
-                        brack=(-5, 5),
+                        brack=(lb_bracket, ub_bracket),
                         full_output=False
                     )
                 else:
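
Note that both the multiprocessing and the serial branch now clip the candidate `x` to the model's `b_var` bounds inside the cost function, since Brent's bracketing phase can probe points outside the initial interval, where an extreme dispersion may overflow the likelihood. A sketch of that guard in isolation (assuming scalar `lb`/`ub` as returned by `param_bounds`):

    import numpy as np

    def make_bounded_cost(ll, lb, ub):
        # Wrap a log-likelihood so the optimizer never evaluates it out of bounds.
        def cost(x, *args):
            x = np.clip(np.array([[x]]), lb, ub)  # Brent may step outside the bracket
            return -np.sum(ll(x, *args))
        return cost

    cost = make_bounded_cost(lambda x: -((x - 1.0) ** 2), lb=-30.0, ub=30.0)
    cost(1000.0)  # evaluated at the clipped value 30.0, not at 1000.0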
@@ -489,8 +511,7 @@ def finalize(self):
         transfers relevant attributes.
         """
         # Read from numpy-IRLS estimator specific model:
-
-        self._hessian = self.model.hessian.compute()
+        self._hessian = -self.model.fim.compute()
         self._fisher_inv = np.linalg.inv(-self._hessian)
         self._jacobian = np.sum(np.abs(self.model.jac.compute() / self.model.x.shape[0]), axis=1)
         self._log_likelihood = self.model.ll_byfeature.compute()
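
For reference, `finalize` now derives the Hessian from the Fisher information matrix rather than a separate Hessian computation: for the expected information, H = -FIM at the optimum, so `np.linalg.inv(-self._hessian)` is the inverse FIM either way. A toy check of that identity, assuming a symmetric positive-definite FIM:

    import numpy as np

    fim = np.array([[4.0, 1.0],
                    [1.0, 3.0]])          # toy Fisher information (SPD)
    hessian = -fim                        # expected Hessian at the optimum
    fisher_inv = np.linalg.inv(-hessian)  # identical to np.linalg.inv(fim)
    assert np.allclose(fisher_inv, np.linalg.inv(fim))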