Skip to content

Commit e699985

Browse files
Adapted dispersion fitting in NumPy; added optional code for Wolfe line search
1 parent f50e79f commit e699985

File tree

1 file changed

+31
-10
lines changed

1 file changed

+31
-10
lines changed

batchglm/train/numpy/base_glm/estimator.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def train(
5252
update_b_freq: int = 5,
5353
ftol_b: float = 1e-8,
5454
lr_b: float = 1e-2,
55-
max_iter_b: int = 100,
55+
max_iter_b: int = 1000,
5656
nproc: int = 3,
5757
**kwargs
5858
):
@@ -250,7 +250,7 @@ def iwls_step(
250250
251251
:return: (inferred param x features)
252252
"""
253-
w = self.model.fim_weight_j(j=idx_update) # (observations x features)
253+
w = self.model.fim_weight_aa_j(j=idx_update) # (observations x features)
254254
ybar = self.model.ybar_j(j=idx_update) # (observations x features)
255255
# Translate to problem of form ax = b for each feature:
256256
# (in the following, X=design and Y=counts)
@@ -368,6 +368,7 @@ def _b_step_gd(
368368

369369
def optim_handle(
370370
self,
371+
b_j,
371372
data_j,
372373
eta_loc_j,
373374
xh_scale,
@@ -381,17 +382,33 @@ def optim_handle(
381382
data_j = np.expand_dims(data_j, axis=-1)
382383

383384
ll = self.model.ll_handle()
385+
lb, ub = self.model.param_bounds(dtype=data_j.dtype)
386+
lb_bracket = np.max([lb["b_var"], b_j - 20])
387+
ub_bracket = np.min([ub["b_var"], b_j + 20])
384388

385389
def cost_b_var(x, data_jj, eta_loc_jj, xh_scale_jj):
386-
x = np.array([[x]])
390+
x = np.clip(np.array([[x]]), lb["b_var"], ub["b_var"])
387391
return - np.sum(ll(data_jj, eta_loc_jj, x, xh_scale_jj))
388392

393+
# jac_b = self.model.jac_b_handle()
394+
# def cost_b_var_prime(x, data_jj, eta_loc_jj, xh_scale_jj):
395+
# x = np.clip(np.array([[x]]), lb["b_var"], ub["b_var"])
396+
# return - np.sum(jac_b(data_jj, eta_loc_jj, x, xh_scale_jj))
397+
# return scipy.optimize.line_search(
398+
# f=cost_b_var,
399+
# myfprime=cost_b_var_prime,
400+
# args=(data_j, eta_loc_j, xh_scale),
401+
# maxiter=max_iter,
402+
# xk=b_j+5,
403+
# pk=-np.ones_like(b_j)
404+
# )
405+
389406
return scipy.optimize.brent(
390407
func=cost_b_var,
391408
args=(data_j, eta_loc_j, xh_scale),
392409
maxiter=max_iter,
393410
tol=ftol,
394-
brack=(-5, 5),
411+
brack=(lb_bracket, ub_bracket),
395412
full_output=True
396413
)
397414

@@ -407,13 +424,13 @@ def _b_step_loop(
407424
408425
:return:
409426
"""
410-
x0 = -10
411427
delta_theta = np.zeros_like(self.model.b_var)
412428
if isinstance(delta_theta, dask.array.core.Array):
413429
delta_theta = delta_theta.compute()
414430

415431
xh_scale = np.matmul(self.model.design_scale, self.model.constraints_scale).compute()
416-
if nproc > 1:
432+
b_var = self.model.b_var.compute()
433+
if nproc > 1 and len(idx_update) > nproc:
417434
sys.stdout.write('\rFitting %i dispersion models: (progress not available with multiprocessing)' % len(idx_update))
418435
sys.stdout.flush()
419436
with multiprocessing.Pool(processes=nproc) as pool:
@@ -422,6 +439,7 @@ def _b_step_loop(
422439
results = pool.starmap(
423440
self.optim_handle,
424441
[(
442+
b_var[0, j],
425443
x[:, [j]],
426444
eta_loc[:, [j]],
427445
xh_scale,
@@ -452,12 +470,16 @@ def _b_step_loop(
452470
data = data.todense()
453471

454472
ll = self.model.ll_handle()
473+
lb, ub = self.model.param_bounds(dtype=data.dtype)
474+
lb_bracket = np.max([lb["b_var"], b_var[0, j] - 20])
475+
ub_bracket = np.min([ub["b_var"], b_var[0, j] + 20])
455476

456477
def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
478+
x = np.clip(np.array([[x]]), lb["b_var"], ub["b_var"])
457479
return - np.sum(ll(
458480
data_j,
459481
eta_loc_j,
460-
np.array([[x]]),
482+
x,
461483
xh_scale_j
462484
))
463485

@@ -466,7 +488,7 @@ def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
466488
args=(data, eta_loc, xh_scale),
467489
maxiter=max_iter,
468490
tol=ftol,
469-
brack=(-5, 5),
491+
brack=(lb_bracket, ub_bracket),
470492
full_output=False
471493
)
472494
else:
@@ -489,8 +511,7 @@ def finalize(self):
489511
transfers relevant attributes.
490512
"""
491513
# Read from numpy-IRLS estimator specific model:
492-
493-
self._hessian = self.model.hessian.compute()
514+
self._hessian = - self.model.fim.compute()
494515
self._fisher_inv = np.linalg.inv(- self._hessian)
495516
self._jacobian = np.sum(np.abs(self.model.jac.compute() / self.model.x.shape[0]), axis=1)
496517
self._log_likelihood = self.model.ll_byfeature.compute()

0 commit comments

Comments (0)