@@ -56,68 +56,137 @@ def train(
             nproc: int = 3,
             **kwargs
     ):
+        """
+        Train GLM.
+
+        Convergence decision:
+        Location and scale model updates are performed in separate iterations and with different algorithms.
+        Scale model updates are much less frequent (only every update_b_freq-th iteration) as they are much slower.
+        During a stretch of update_b_freq location model updates between two scale model updates, convergence
+        of the location model is tracked with self.model.converged. This tracker is reset after each scale model
+        update, as location model convergence only holds conditioned on a particular scale model value.
+        Full convergence of a feature-wise model is evaluated after each scale model update: if the loss-based
+        convergence criterion holds across the cumulative updates of the sequence of location updates and the last
+        scale model update, the feature is considered converged. For this, the loss value at the last scale model
+        update is saved in ll_last_b_update. Full convergence is saved in fully_converged.
+
+        :param max_steps: Maximum number of training iterations.
+        :param method_b: Method used for scale model updates ("gd" for gradient descent, otherwise e.g. "brent").
+        :param update_b_freq: One over the minimum frequency of scale model updates per location model update:
+            a scale model update is run at least every update_b_freq location model update iterations.
+        :param ftol_b: Convergence tolerance of the scale model optimizer.
+        :param lr_b: Learning rate of the gradient descent scale model optimizer (method_b == "gd").
+        :param max_iter_b: Maximum number of iterations per scale model update.
+        :param nproc: Number of processes used to fit the feature-wise scale (dispersion) models in parallel.
+        :param kwargs: Further keyword arguments.
+        :return:
+        """
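+        # Illustrative schedule, assuming update_b_freq = 5 and no early convergence of all location models:
+        #   iteration: 1   2   3   4   5   6     7   8   ...
+        #   update:    loc loc loc loc loc scale loc loc ...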
         # Iterate until conditions are fulfilled.
         train_step = 0
-        delayed_converged = np.tile(False, self.model.model_vars.n_features)
+        epochs_until_b_update = update_b_freq
+        fully_converged = np.tile(False, self.model.model_vars.n_features)
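+        # fully_converged tracks terminal (location + scale) convergence per feature, while
+        # self.model.converged tracks location model convergence conditional on the current scale model.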
 
         ll_current = -self.model.ll_byfeature.compute()
+        ll_last_b_update = ll_current.copy()
         #logging.getLogger("batchglm").info(
         sys.stdout.write("iter %i: ll=%f\n" % (0, np.sum(ll_current)))
-        while np.any(np.logical_not(delayed_converged)) and \
+        while np.any(np.logical_not(fully_converged)) and \
                 train_step < max_steps:
             t0 = time.time()
-            # Update parameters:
             # Line search step for scale model:
-            if train_step % update_b_freq == 0 and train_step > 0:
-                if isinstance(self.model.b_var, dask.array.core.Array):
-                    b_var_cache = self.model.b_var.compute()
-                else:
-                    b_var_cache = self.model.b_var.copy()
-                self.model.b_var = self.b_step(
-                    idx=np.where(np.logical_not(delayed_converged))[0],
+            # Run this update every update_b_freq iterations.
+            if epochs_until_b_update == 0:
+                # Compute update.
+                idx_update = np.where(np.logical_not(fully_converged))[0]
+                b_step = self.b_step(
+                    idx_update=idx_update,
                     method=method_b,
                     ftol=ftol_b,
                     lr=lr_b,
                     max_iter=max_iter_b,
                     nproc=nproc
                 )
+                # Perform trial update.
+                self.model.b_var = self.model.b_var + b_step
                 # Reverse update by feature if update leads to worse loss:
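+                # (The ll_* arrays hold negative log-likelihoods, so a larger value means a worse fit.)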
-                ll_proposal = -self.model.ll_byfeature.compute()
+                ll_proposal = -self.model.ll_byfeature_j(j=idx_update).compute()
+                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
                 if isinstance(self.model.b_var, dask.array.core.Array):
                     b_var_new = self.model.b_var.compute()
                 else:
                     b_var_new = self.model.b_var.copy()
-                b_var_new[:, ll_proposal > ll_current] = b_var_cache[:, ll_proposal > ll_current]
+                b_var_new[:, idx_bad_step] = b_var_new[:, idx_bad_step] - b_step[:, idx_bad_step]
                 self.model.b_var = b_var_new
-                delayed_converged = self.model.converged.copy()
-            # IWLS step for location model:
-            if np.any(np.logical_not(self.model.converged)) or train_step % update_b_freq == 0 and train_step > 0:
-                self.model.a_var = self.model.a_var + self.iwls_step()
-            # Evaluate convergence
-            ll_previous = ll_current
-            ll_current = -self.model.ll_byfeature.compute()
-            converged_f = np.logical_or(
-                ll_previous < ll_current,  # loss gets worse
-                np.abs(ll_previous - ll_current) / np.maximum(  # relative decrease in loss is too small
-                    np.nextafter(0, np.inf, dtype=ll_previous.dtype),  # catch division by zero
-                    np.abs(ll_previous)
-                ) < pkg_constants.LLTOL_BY_FEATURE,
-            )
-            # Location model convergence status has to be updated if b model was updated
-            if train_step % update_b_freq == 0 and train_step > 0:
-                self.model.converged = converged_f
-                delayed_converged = converged_f
+                # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
+                ll_new = ll_current.copy()
+                ll_new[idx_update] = ll_proposal
+                ll_new[idx_bad_step] = ll_current[idx_bad_step]
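+                # Features whose trial update was reverted keep their previous likelihood value.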
+                # Reset b model update counter.
+                epochs_until_b_update = update_b_freq
+            else:
+                # IWLS step for location model:
+                # Compute update.
+                idx_update = self.model.idx_not_converged
+                a_step = self.iwls_step(idx_update=idx_update)
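+                # iwls_step solves the per-feature weighted least squares (normal) equations and returns an increment for a_var.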
+                # Perform trial update.
+                self.model.a_var = self.model.a_var + a_step
+                # Reverse update by feature if update leads to worse loss:
+                ll_proposal = -self.model.ll_byfeature_j(j=idx_update).compute()
+                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
+                if isinstance(self.model.a_var, dask.array.core.Array):
+                    a_var_new = self.model.a_var.compute()
+                else:
+                    a_var_new = self.model.a_var.copy()
+                a_var_new[:, idx_bad_step] = a_var_new[:, idx_bad_step] - a_step[:, idx_bad_step]
+                self.model.a_var = a_var_new
+                # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
+                ll_new = ll_current.copy()
+                ll_new[idx_update] = ll_proposal
+                ll_new[idx_bad_step] = ll_current[idx_bad_step]
+                # Update epoch counter of a updates until next b update:
+                epochs_until_b_update -= 1
+
+            # Evaluate and update convergence:
+            ll_previous = ll_current
+            ll_current = ll_new
+            if epochs_until_b_update == update_b_freq:  # b step update was executed.
+                # Update terminal convergence in fully_converged and intermediate convergence in self.model.converged.
+                converged_f = np.logical_or(
+                    ll_last_b_update < ll_current,  # loss gets worse
+                    np.abs(ll_last_b_update - ll_current) / np.maximum(  # relative decrease in loss is too small
+                        np.nextafter(0, np.inf, dtype=ll_previous.dtype),  # catch division by zero
+                        np.abs(ll_last_b_update)
+                    ) < pkg_constants.LLTOL_BY_FEATURE,
+                )
+                self.model.converged = np.logical_or(fully_converged, converged_f)
+                ll_last_b_update = ll_current.copy()
+                fully_converged = self.model.converged.copy()
             else:
+                # Update intermediate convergence in self.model.converged.
+                converged_f = np.logical_or(
+                    ll_previous < ll_current,  # loss gets worse
+                    np.abs(ll_previous - ll_current) / np.maximum(  # relative decrease in loss is too small
+                        np.nextafter(0, np.inf, dtype=ll_previous.dtype),  # catch division by zero
+                        np.abs(ll_previous)
+                    ) < pkg_constants.LLTOL_BY_FEATURE,
+                )
                 self.model.converged = np.logical_or(self.model.converged, converged_f)
+                if np.all(self.model.converged):
+                    # All location models are converged. This means that the next update will be b model
+                    # update and all remaining intermediate a model updates can be skipped:
+                    epochs_until_b_update = 0
+
+            # Conclude and report iteration.
             train_step += 1
             #logging.getLogger("batchglm").info(
             sys.stdout.write(
-                "iter %s: ll=%f, converged: %.2f%% (location model: %.2f%%), in %.2fsec\n" %
+                "iter %s: ll=%f, converged: %.2f%% (loc: %.2f%%, scale update: %s), in %.2fsec\n" %
                 (
                     (" " if train_step < 10 else "") + (" " if train_step < 100 else "") + str(train_step),
                     np.sum(ll_current),
-                    np.mean(delayed_converged) * 100,
+                    np.mean(fully_converged) * 100,
                     np.mean(self.model.converged) * 100,
+                    str(epochs_until_b_update == update_b_freq),
                     time.time() - t0
                 )
             )
@@ -143,6 +212,7 @@ def a_step_gd(
         :return:
         """
         iter = 0
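+        # Cache the current location model parameters so that this routine can return an increment (see return below).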
+        a_var_old = self.model.a_var.compute()
         converged = np.tile(True, self.model.model_vars.n_features)
         converged[idx] = False
         ll_current = -self.model.ll_byfeature.compute()
@@ -170,15 +240,18 @@ def a_step_gd(
                     np.mean(converged) * 100
                 )
             )
-        return self.model.a_var.compute()
+        return self.model.a_var.compute() - a_var_old
 
-    def iwls_step(self) -> np.ndarray:
+    def iwls_step(
+            self,
+            idx_update: np.ndarray
+    ) -> np.ndarray:
         """
 
         :return: (inferred param x features)
         """
-        w = self.model.fim_weight_j(j=self.model.idx_not_converged)  # (observations x features)
-        ybar = self.model.ybar_j(j=self.model.idx_not_converged)  # (observations x features)
+        w = self.model.fim_weight_j(j=idx_update)  # (observations x features)
+        ybar = self.model.ybar_j(j=idx_update)  # (observations x features)
         # Translate to problem of form ax = b for each feature:
         # (in the following, X=design and Y=counts)
         # a=X^T*W*X: ([features] x inferred param)
@@ -188,7 +261,7 @@ def iwls_step(self) -> np.ndarray:
         xhw = np.einsum('ob,of->fob', xh, w)
         a = np.einsum('fob,oc->fbc', xhw, xh)
         b = np.einsum('fob,of->fb', xhw, ybar)
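+        # (einsum letters: o = observations, b and c = model parameters, f = features.)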
-        # Via np.linalg.solve:
+
         delta_theta = np.zeros_like(self.model.a_var)
         if isinstance(delta_theta, dask.array.core.Array):
             delta_theta = delta_theta.compute()
@@ -197,31 +270,31 @@ def iwls_step(self) -> np.ndarray:
             # Have to use a workaround to solve problems in parallel in dask here. This workaround does
             # not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
             if a.shape[0] != 1:
-                delta_theta[:, self.model.idx_not_converged] = dask.array.map_blocks(
+                delta_theta[:, idx_update] = dask.array.map_blocks(
                     np.linalg.solve, a, b[:, :, None], chunks=b[:, :, None].shape
                 ).squeeze().T.compute()
             else:
-                delta_theta[:, self.model.idx_not_converged] = np.expand_dims(
+                delta_theta[:, idx_update] = np.expand_dims(
                     np.linalg.solve(a[0], b[0]).compute(),
                     axis=-1
                 )
         else:
-            delta_theta[:, self.model.idx_not_converged] = np.linalg.solve(a, b).T
+            delta_theta[:, idx_update] = np.linalg.solve(a, b).T
         # Via np.linalg.lstsq:
-        #delta_theta[:, self.idx_not_converged] = np.concatenate([
+        #delta_theta[:, idx_update] = np.concatenate([
         #    np.expand_dims(np.linalg.lstsq(a[i, :, :], b[i, :])[0], axis=-1)
-        #    for i in self.idx_not_converged)
+        #    for i in idx_update)
         #], axis=-1)
         # Via np.linalg.inv:
-        # #delta_theta[:, self.idx_not_converged] = np.concatenate([
+        # #delta_theta[:, idx_update] = np.concatenate([
         #    np.expand_dims(np.matmul(np.linalg.inv(a[i, :, :]), b[i, :]), axis=-1)
-        #    for i in self.idx_not_converged)
+        #    for i in idx_update)
         #], axis=-1)
         return delta_theta
 
     def b_step(
             self,
-            idx: np.ndarray,
+            idx_update: np.ndarray,
             method: str,
             ftol: float,
             lr: float,
@@ -234,14 +307,14 @@ def b_step(
         """
         if method.lower() in ["gd"]:
             return self._b_step_gd(
-                idx=idx,
+                idx_update=idx_update,
                 ftol=ftol,
                 lr=lr,
                 max_iter=max_iter
             )
         else:
             return self._b_step_loop(
-                idx=idx,
+                idx_update=idx_update,
                 method=method,
                 ftol=ftol,
                 max_iter=max_iter,
@@ -250,7 +323,7 @@
 
     def _b_step_gd(
             self,
-            idx: np.ndarray,
+            idx_update: np.ndarray,
             ftol: float,
             max_iter: int,
             lr: float
@@ -260,8 +333,9 @@ def _b_step_gd(
         :return:
         """
         iter = 0
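+        # Cache the current scale model parameters so that this routine can return an increment (see return below).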
+        b_var_old = self.model.b_var.compute()
         converged = np.tile(True, self.model.model_vars.n_features)
-        converged[idx] = False
+        converged[idx_update] = False
         ll_current = -self.model.ll_byfeature.compute()
         while np.any(np.logical_not(converged)) and iter < max_iter:
             idx_to_update = np.where(np.logical_not(converged))[0]
@@ -290,7 +364,7 @@ def _b_step_gd(
                     np.mean(converged) * 100
                 )
             )
-        return self.model.b_var.compute()
+        return self.model.b_var.compute() - b_var_old
 
     def optim_handle(
             self,
@@ -300,8 +374,8 @@ def optim_handle(
             max_iter,
             ftol
     ):
-
-        if isinstance(data_j, sparse._coo.core.COO) or isinstance(data_j, scipy.sparse.csr_matrix):
+        # Need to supply dense numpy array to scipy optimize:
+        if isinstance(data_j, sparse.COO) or isinstance(data_j, scipy.sparse.csr_matrix):
             data_j = data_j.todense()
         if len(data_j.shape) == 1:
             data_j = np.expand_dims(data_j, axis=-1)
@@ -323,7 +397,7 @@ def cost_b_var(x, data_jj, eta_loc_jj, xh_scale_jj):
 
     def _b_step_loop(
             self,
-            idx: np.ndarray,
+            idx_update: np.ndarray,
             method: str,
             max_iter: int,
             ftol: float,
@@ -334,15 +408,13 @@ def _b_step_loop(
         :return:
         """
         x0 = -10
-
-        if isinstance(self.model.b_var, dask.array.core.Array):
-            b_var_new = self.model.b_var.compute()
-        else:
-            b_var_new = self.model.b_var.copy()
+        delta_theta = np.zeros_like(self.model.b_var)
+        if isinstance(delta_theta, dask.array.core.Array):
+            delta_theta = delta_theta.compute()
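+        # delta_theta first collects the fitted scale model parameters per feature; it is turned into an increment before returning.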
 
         xh_scale = np.matmul(self.model.design_scale, self.model.constraints_scale).compute()
         if nproc > 1:
-            sys.stdout.write('\rFitting %i dispersion models: (progress not available with multiprocessing)' % len(idx))
+            sys.stdout.write('\rFitting %i dispersion models: (progress not available with multiprocessing)' % len(idx_update))
            sys.stdout.flush()
            with multiprocessing.Pool(processes=nproc) as pool:
                x = self.x.compute()
@@ -355,27 +427,28 @@ def _b_step_loop(
                         xh_scale,
                         max_iter,
                         ftol
-                    ) for j in idx]
+                    ) for j in idx_update]
                 )
                 pool.close()
-            b_var_new[0, idx] = np.array([x[0] for x in results])
+            delta_theta[0, idx_update] = np.array([x[0] for x in results])
             sys.stdout.write('\r')
             sys.stdout.flush()
         else:
             t0 = time.time()
-            for i, j in enumerate(idx):
+            for i, j in enumerate(idx_update):
                 sys.stdout.write(
                     '\rFitting dispersion models: %.2f%% in %.2fsec' %
                     (
-                        np.round(i / len(idx) * 100., 2),
+                        np.round(i / len(idx_update) * 100., 2),
                         time.time() - t0
                     )
                 )
                 sys.stdout.flush()
                 if method.lower() == "brent":
                     eta_loc = self.model.eta_loc_j(j=j).compute()
                     data = self.x[:, [j]].compute()
-                    if isinstance(data, sparse._coo.core.COO) or isinstance(data, scipy.sparse.csr_matrix):
+                    # Need to supply dense numpy array to scipy optimize:
+                    if isinstance(data, sparse.COO) or isinstance(data, scipy.sparse.csr_matrix):
                         data = data.todense()
 
                     ll = self.model.ll_handle()
@@ -388,7 +461,7 @@ def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
                             xh_scale_j
                         ))
 
-                    b_var_new[0, j] = scipy.optimize.brent(
+                    delta_theta[0, j] = scipy.optimize.brent(
                         func=cost_b_var,
                         args=(data, eta_loc, xh_scale),
                         maxiter=max_iter,
@@ -400,7 +473,12 @@ def cost_b_var(x, data_j, eta_loc_j, xh_scale_j):
                     raise ValueError("method %s not recognized" % method)
         sys.stdout.write('\r')
         sys.stdout.flush()
-        return b_var_new
+
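+        # The optimizers above return new scale parameter values; the current b_var is subtracted so that an
+        # increment is returned, matching the "b_var + b_step" update in train().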
+        if isinstance(self.model.b_var, dask.array.core.Array):
+            delta_theta[:, idx_update] = delta_theta[:, idx_update] - self.model.b_var.compute()[:, idx_update]
+        else:
+            delta_theta[:, idx_update] = delta_theta[:, idx_update] - self.model.b_var.copy()[:, idx_update]
+        return delta_theta
 
     def finalize(self):
         """