
Commit 8dcd460

improved numpy backend convergence decision, reduced unnecessary likelihood evaluations
1 parent b57229f commit 8dcd460
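
For orientation, the reworked training loop decides per-feature convergence by comparing the relative decrease of the negative log-likelihood against pkg_constants.LLTOL_BY_FEATURE (see the converged_f computation in the diff below). The snippet below is a minimal sketch that exercises that criterion in isolation; the loss values and the tolerance constant are made-up stand-ins, not part of the commit.

import numpy as np

# Hypothetical stand-in for pkg_constants.LLTOL_BY_FEATURE.
LLTOL_BY_FEATURE = 1e-6

# Made-up negative log-likelihoods per feature before and after an update.
ll_previous = np.array([1000.0, 500.0, 250.0])
ll_current = np.array([999.0, 500.0000001, 251.0])

# A feature counts as converged if its loss got worse (such a step is reverted anyway)
# or if the relative improvement fell below the tolerance.
converged_f = np.logical_or(
    ll_previous < ll_current, # loss gets worse
    np.abs(ll_previous - ll_current) / np.maximum(
        np.nextafter(0, np.inf, dtype=ll_previous.dtype), # catch division by zero
        np.abs(ll_previous)
    ) < LLTOL_BY_FEATURE,
)
print(converged_f)  # [False  True  True]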

File tree

1 file changed

+144 -66 lines changed


batchglm/train/numpy/base_glm/estimator.py

Lines changed: 144 additions & 66 deletions
@@ -56,68 +56,137 @@ def train(
             nproc: int = 3,
             **kwargs
     ):
+        """
+        Train GLM.
+
+        Convergence decision:
+        Location and scale model updates are done in separate iterations and use different algorithms.
+        Scale model updates are much less frequent (only every update_b_freq-th iteration) as they are much slower.
+        During a stretch of update_b_freq location model updates between two scale model updates, convergence
+        of the location model is tracked with self.model.converged. This is reset after a scale model update, as this
+        convergence only holds conditioned on a particular scale model value.
+        Full convergence of a feature-wise model is evaluated after each scale model update: if the loss-function-based
+        convergence criterion holds across the cumulative updates of the sequence of location updates and the last scale
+        model update, the feature is considered converged. For this, the loss value at the last scale model update is
+        saved in ll_last_b_update. Full convergence is saved in fully_converged.
+
+        :param max_steps:
+        :param method_b:
+        :param update_b_freq: One over the minimum frequency of scale model updates per location model update.
+            A scale model update will be run at least every update_b_freq location model update iterations.
+        :param ftol_b:
+        :param lr_b:
+        :param max_iter_b:
+        :param nproc:
+        :param kwargs:
+        :return:
+        """
         # Iterate until conditions are fulfilled.
         train_step = 0
-        delayed_converged = np.tile(False, self.model.model_vars.n_features)
+        epochs_until_b_update = update_b_freq
+        fully_converged = np.tile(False, self.model.model_vars.n_features)
 
         ll_current = - self.model.ll_byfeature.compute()
+        ll_last_b_update = ll_current.copy()
         #logging.getLogger("batchglm").info(
         sys.stdout.write("iter %i: ll=%f\n" % (0, np.sum(ll_current)))
-        while np.any(np.logical_not(delayed_converged)) and \
+        while np.any(np.logical_not(fully_converged)) and \
                 train_step < max_steps:
             t0 = time.time()
-            # Update parameters:
             # Line search step for scale model:
-            if train_step % update_b_freq == 0 and train_step > 0:
-                if isinstance(self.model.b_var, dask.array.core.Array):
-                    b_var_cache = self.model.b_var.compute()
-                else:
-                    b_var_cache = self.model.b_var.copy()
-                self.model.b_var = self.b_step(
-                    idx=np.where(np.logical_not(delayed_converged))[0],
+            # Run this update every update_b_freq iterations.
+            if epochs_until_b_update == 0:
+                # Compute update.
+                idx_update = np.where(np.logical_not(fully_converged))[0]
+                b_step = self.b_step(
+                    idx_update=idx_update,
                     method=method_b,
                     ftol=ftol_b,
                     lr=lr_b,
                     max_iter=max_iter_b,
                     nproc=nproc
                 )
+                # Perform trial update.
+                self.model.b_var = self.model.b_var + b_step
                 # Reverse update by feature if update leads to worse loss:
-                ll_proposal = - self.model.ll_byfeature.compute()
+                ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
+                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
                 if isinstance(self.model.b_var, dask.array.core.Array):
                     b_var_new = self.model.b_var.compute()
                 else:
                     b_var_new = self.model.b_var.copy()
-                b_var_new[:, ll_proposal > ll_current] = b_var_cache[:, ll_proposal > ll_current]
+                b_var_new[:, idx_bad_step] = b_var_new[:, idx_bad_step] - b_step[:, idx_bad_step]
                 self.model.b_var = b_var_new
-                delayed_converged = self.model.converged.copy()
-            # IWLS step for location model:
-            if np.any(np.logical_not(self.model.converged)) or train_step % update_b_freq == 0 and train_step > 0:
-                self.model.a_var = self.model.a_var + self.iwls_step()
-            # Evaluate convergence
-            ll_previous = ll_current
-            ll_current = - self.model.ll_byfeature.compute()
-            converged_f = np.logical_or(
-                ll_previous < ll_current, # loss gets worse
-                np.abs(ll_previous - ll_current) / np.maximum( # relative decrease in loss is too small
-                    np.nextafter(0, np.inf, dtype=ll_previous.dtype), # catch division by zero
-                    np.abs(ll_previous)
-                ) < pkg_constants.LLTOL_BY_FEATURE,
-            )
-            # Location model convergence status has to be updated if b model was updated
-            if train_step % update_b_freq == 0 and train_step > 0:
-                self.model.converged = converged_f
-                delayed_converged = converged_f
+                # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
+                ll_new = ll_current.copy()
+                ll_new[idx_update] = ll_proposal
+                ll_new[idx_bad_step] = ll_current[idx_bad_step]
+                # Reset b model update counter.
+                epochs_until_b_update = update_b_freq
+            else:
+                # IWLS step for location model:
+                # Compute update.
+                idx_update = self.model.idx_not_converged
+                a_step = self.iwls_step(idx_update=idx_update)
+                # Perform trial update.
+                self.model.a_var = self.model.a_var + a_step
+                # Reverse update by feature if update leads to worse loss:
+                ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
+                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
+                if isinstance(self.model.b_var, dask.array.core.Array):
+                    a_var_new = self.model.a_var.compute()
+                else:
+                    a_var_new = self.model.a_var.copy()
+                a_var_new[:, idx_bad_step] = a_var_new[:, idx_bad_step] - a_step[:, idx_bad_step]
+                self.model.a_var = a_var_new
+                # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
+                ll_new = ll_current.copy()
+                ll_new[idx_update] = ll_proposal
+                ll_new[idx_bad_step] = ll_current[idx_bad_step]
+                # Update epoch counter of a updates until next b update:
+                epochs_until_b_update -= 1
+
+            # Evaluate and update convergence:
+            ll_previous = ll_current
+            ll_current = ll_new
+            if epochs_until_b_update == update_b_freq: # b step update was executed.
+                # Update terminal convergence in fully_converged and intermediate convergence in self.model.converged.
+                converged_f = np.logical_or(
+                    ll_last_b_update < ll_current, # loss gets worse
+                    np.abs(ll_last_b_update - ll_current) / np.maximum( # relative decrease in loss is too small
+                        np.nextafter(0, np.inf, dtype=ll_previous.dtype), # catch division by zero
+                        np.abs(ll_last_b_update)
+                    ) < pkg_constants.LLTOL_BY_FEATURE,
+                )
+                self.model.converged = np.logical_or(fully_converged, converged_f)
+                ll_last_b_update = ll_current.copy()
+                fully_converged = self.model.converged.copy()
             else:
+                # Update intermediate convergence in self.model.converged.
+                converged_f = np.logical_or(
+                    ll_previous < ll_current, # loss gets worse
+                    np.abs(ll_previous - ll_current) / np.maximum( # relative decrease in loss is too small
+                        np.nextafter(0, np.inf, dtype=ll_previous.dtype), # catch division by zero
+                        np.abs(ll_previous)
+                    ) < pkg_constants.LLTOL_BY_FEATURE,
+                )
                 self.model.converged = np.logical_or(self.model.converged, converged_f)
+                if np.all(self.model.converged):
+                    # All location models are converged. This means that the next update will be a b model
+                    # update and all remaining intermediate a model updates can be skipped:
+                    epochs_until_b_update = 0
+
+            # Conclude and report iteration.
             train_step += 1
             #logging.getLogger("batchglm").info(
             sys.stdout.write(
-                "iter %s: ll=%f, converged: %.2f%% (location model: %.2f%%), in %.2fsec\n" %
+                "iter %s: ll=%f, converged: %.2f%% (loc: %.2f%%, scale update: %s), in %.2fsec\n" %
                 (
                     (" " if train_step < 10 else "") + (" " if train_step < 100 else "") + str(train_step),
                     np.sum(ll_current),
-                    np.mean(delayed_converged)*100,
+                    np.mean(fully_converged)*100,
                     np.mean(self.model.converged) * 100,
+                    str(epochs_until_b_update == update_b_freq),
                     time.time()-t0
                 )
             )
@@ -143,6 +212,7 @@ def a_step_gd(
         :return:
         """
         iter = 0
+        a_var_old = self.model.a_var.compute()
         converged = np.tile(True, self.model.model_vars.n_features)
         converged[idx] = False
         ll_current = - self.model.ll_byfeature.compute()
@@ -170,15 +240,18 @@ def a_step_gd(
                     np.mean(converged) * 100
                 )
             )
-        return self.model.a_var.compute()
+        return self.model.a_var.compute() - a_var_old
 
-    def iwls_step(self) -> np.ndarray:
+    def iwls_step(
+            self,
+            idx_update: np.ndarray
+    ) -> np.ndarray:
         """
 
        :return: (inferred param x features)
         """
-        w = self.model.fim_weight_j(j=self.model.idx_not_converged) # (observations x features)
-        ybar = self.model.ybar_j(j=self.model.idx_not_converged) # (observations x features)
+        w = self.model.fim_weight_j(j=idx_update) # (observations x features)
+        ybar = self.model.ybar_j(j=idx_update) # (observations x features)
         # Translate to problem of form ax = b for each feature:
         # (in the following, X=design and Y=counts)
         # a=X^T*W*X: ([features] x inferred param)
@@ -188,7 +261,7 @@ def iwls_step(self) -> np.ndarray:
         xhw = np.einsum('ob,of->fob', xh, w)
         a = np.einsum('fob,oc->fbc', xhw, xh)
         b = np.einsum('fob,of->fb', xhw, ybar)
-        # Via np.linalg.solve:
+
         delta_theta = np.zeros_like(self.model.a_var)
         if isinstance(delta_theta, dask.array.core.Array):
             delta_theta = delta_theta.compute()
@@ -197,31 +270,31 @@ def iwls_step(self) -> np.ndarray:
             # Have to use a workaround to solve problems in parallel in dask here. This workaround does
             # not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
             if a.shape[0] != 1:
-                delta_theta[:, self.model.idx_not_converged] = dask.array.map_blocks(
+                delta_theta[:, idx_update] = dask.array.map_blocks(
                     np.linalg.solve, a, b[:, :, None], chunks=b[:, :, None].shape
                 ).squeeze().T.compute()
             else:
-                delta_theta[:, self.model.idx_not_converged] = np.expand_dims(
+                delta_theta[:, idx_update] = np.expand_dims(
                     np.linalg.solve(a[0], b[0]).compute(),
                     axis=-1
                 )
         else:
-            delta_theta[:, self.model.idx_not_converged] = np.linalg.solve(a, b).T
+            delta_theta[:, idx_update] = np.linalg.solve(a, b).T
         # Via np.linalg.lsts:
-        #delta_theta[:, self.idx_not_converged] = np.concatenate([
+        #delta_theta[:, idx_update] = np.concatenate([
         #    np.expand_dims(np.linalg.lstsq(a[i, :, :], b[i, :])[0], axis=-1)
-        #    for i in self.idx_not_converged)
+        #    for i in idx_update)
         #], axis=-1)
         # Via np.linalg.inv:
-        # #delta_theta[:, self.idx_not_converged] = np.concatenate([
+        # #delta_theta[:, idx_update] = np.concatenate([
         #    np.expand_dims(np.matmul(np.linalg.inv(a[i, :, :]), b[i, :]), axis=-1)
-        #    for i in self.idx_not_converged)
+        #    for i in idx_update)
         #], axis=-1)
         return delta_theta
 
     def b_step(
             self,
-            idx: np.ndarray,
+            idx_update: np.ndarray,
             method: str,
             ftol: float,
             lr: float,
@@ -234,14 +307,14 @@ def b_step(
         """
         if method.lower() in ["gd"]:
             return self._b_step_gd(
-                idx=idx,
+                idx_update=idx_update,
                 ftol=ftol,
                 lr=lr,
                 max_iter=max_iter
             )
         else:
             return self._b_step_loop(
-                idx=idx,
+                idx_update=idx_update,
                 method=method,
                 ftol=ftol,
                 max_iter=max_iter,
@@ -250,7 +323,7 @@ def b_step(
 
     def _b_step_gd(
             self,
-            idx: np.ndarray,
+            idx_update: np.ndarray,
             ftol: float,
             max_iter: int,
             lr: float
@@ -260,8 +333,9 @@ def _b_step_gd(
         :return:
         """
         iter = 0
+        b_var_old = self.model.b_var.compute()
         converged = np.tile(True, self.model.model_vars.n_features)
-        converged[idx] = False
+        converged[idx_update] = False
         ll_current = - self.model.ll_byfeature.compute()
         while np.any(np.logical_not(converged)) and iter < max_iter:
             idx_to_update = np.where(np.logical_not(converged))[0]
@@ -290,7 +364,7 @@ def _b_step_gd(
                     np.mean(converged) * 100
                 )
             )
-        return self.model.b_var.compute()
+        return self.model.b_var.compute() - b_var_old
 
     def optim_handle(
             self,
@@ -300,8 +374,8 @@ def optim_handle(
             max_iter,
             ftol
     ):
-
-        if isinstance(data_j, sparse._coo.core.COO) or isinstance(data_j, scipy.sparse.csr_matrix):
+        # Need to supply dense numpy array to scipy optimize:
+        if isinstance(data_j, sparse.COO) or isinstance(data_j, scipy.sparse.csr_matrix):
             data_j = data_j.todense()
         if len(data_j.shape) == 1:
             data_j = np.expand_dims(data_j, axis=-1)
@@ -323,7 +397,7 @@ def cost_b_var(x, data_jj, eta_loc_jj, xh_scale_jj):
 
     def _b_step_loop(
             self,
-            idx: np.ndarray,
+            idx_update: np.ndarray,
             method: str,
             max_iter: int,
             ftol: float,
@@ -334,15 +408,13 @@ def _b_step_loop(
         :return:
         """
         x0 = -10
-
-        if isinstance(self.model.b_var, dask.array.core.Array):
-            b_var_new = self.model.b_var.compute()
-        else:
-            b_var_new = self.model.b_var.copy()
+        delta_theta = np.zeros_like(self.model.b_var)
+        if isinstance(delta_theta, dask.array.core.Array):
+            delta_theta = delta_theta.compute()
 
         xh_scale = np.matmul(self.model.design_scale, self.model.constraints_scale).compute()
         if nproc > 1:
-            sys.stdout.write('\rFitting %i dispersion models: (progress not available with multiprocessing)' % len(idx))
+            sys.stdout.write('\rFitting %i dispersion models: (progress not available with multiprocessing)' % len(idx_update))
             sys.stdout.flush()
             with multiprocessing.Pool(processes=nproc) as pool:
                 x = self.x.compute()
@@ -355,27 +427,28 @@ def _b_step_loop(
                         xh_scale,
                         max_iter,
                         ftol
-                    ) for j in idx]
+                    ) for j in idx_update]
                 )
                 pool.close()
-            b_var_new[0, idx] = np.array([x[0] for x in results])
+            delta_theta[0, idx_update] = np.array([x[0] for x in results])
             sys.stdout.write('\r')
             sys.stdout.flush()
         else:
             t0 = time.time()
-            for i, j in enumerate(idx):
+            for i, j in enumerate(idx_update):
                 sys.stdout.write(
                     '\rFitting dispersion models: %.2f%% in %.2fsec' %
                     (
-                        np.round(i/len(idx)*100., 2),
+                        np.round(i / len(idx_update) * 100., 2),
                         time.time() - t0
                     )
                 )
                 sys.stdout.flush()
                 if method.lower() == "brent":
                     eta_loc = self.model.eta_loc_j(j=j).compute()
                     data = self.x[:, [j]].compute()
-                    if isinstance(data, sparse._coo.core.COO) or isinstance(data, scipy.sparse.csr_matrix):
+                    # Need to supply dense numpy array to scipy optimize:
+                    if isinstance(data, sparse.COO) or isinstance(data, scipy.sparse.csr_matrix):
                         data = data.todense()
 
                     ll = self.model.ll_handle()
@@ -388,7 +461,7 @@ def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
                             xh_scale_j
                         ))
 
-                    b_var_new[0, j] = scipy.optimize.brent(
+                    delta_theta[0, j] = scipy.optimize.brent(
                        func=cost_b_var,
                        args=(data, eta_loc, xh_scale),
                        maxiter=max_iter,
@@ -400,7 +473,12 @@ def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
                    raise ValueError("method %s not recognized" % method)
            sys.stdout.write('\r')
            sys.stdout.flush()
-        return b_var_new
+
+        if isinstance(self.model.b_var, dask.array.core.Array):
+            delta_theta[:, idx_update] = delta_theta[:, idx_update] - self.model.b_var.compute()[:, idx_update]
+        else:
+            delta_theta[:, idx_update] = delta_theta[:, idx_update] - self.model.b_var.copy()[:, idx_update]
+        return delta_theta
 
     def finalize(self):
         """
