@@ -82,6 +82,11 @@ def train(
8282 """
8383 # Iterate until conditions are fulfilled.
8484 train_step = 0
85+ if self ._train_scale :
86+ if not self ._train_loc :
87+ update_b_freq = 1
88+ else :
89+ update_b_freq = np .inf
8590 epochs_until_b_update = update_b_freq
8691 fully_converged = np .tile (False , self .model .model_vars .n_features )
8792
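The five added lines only decide how often the scale (b) model is refreshed: every epoch when only the scale model is trained, never when the scale model is not trained, and at the caller-supplied frequency otherwise. A minimal standalone sketch of that decision, assuming update_b_freq enters with the caller's value (names mirror the diff, the helper itself is illustrative):

import numpy as np

def effective_update_b_freq(train_loc, train_scale, update_b_freq):
    # Scale-only fit: refresh the scale model every epoch.
    if train_scale and not train_loc:
        return 1
    # No scale training: an infinite counter means the b update never fires.
    if not train_scale:
        return np.inf
    # Both sub-models trained: keep the frequency passed by the caller.
    return update_b_freq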
@@ -97,25 +102,29 @@ def train(
             if epochs_until_b_update == 0:
                 # Compute update.
                 idx_update = np.where(np.logical_not(fully_converged))[0]
-                b_step = self.b_step(
-                    idx_update=idx_update,
-                    method=method_b,
-                    ftol=ftol_b,
-                    lr=lr_b,
-                    max_iter=max_iter_b,
-                    nproc=nproc
-                )
-                # Perform trial update.
-                self.model.b_var = self.model.b_var + b_step
-                # Reverse update by feature if update leads to worse loss:
-                ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
-                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
-                if isinstance(self.model.b_var, dask.array.core.Array):
-                    b_var_new = self.model.b_var.compute()
+                if self._train_scale:
+                    b_step = self.b_step(
+                        idx_update=idx_update,
+                        method=method_b,
+                        ftol=ftol_b,
+                        lr=lr_b,
+                        max_iter=max_iter_b,
+                        nproc=nproc
+                    )
+                    # Perform trial update.
+                    self.model.b_var = self.model.b_var + b_step
+                    # Reverse update by feature if update leads to worse loss:
+                    ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
+                    idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
+                    if isinstance(self.model.b_var, dask.array.core.Array):
+                        b_var_new = self.model.b_var.compute()
+                    else:
+                        b_var_new = self.model.b_var.copy()
+                    b_var_new[:, idx_bad_step] = b_var_new[:, idx_bad_step] - b_step[:, idx_bad_step]
+                    self.model.b_var = b_var_new
                 else:
-                    b_var_new = self.model.b_var.copy()
-                b_var_new[:, idx_bad_step] = b_var_new[:, idx_bad_step] - b_step[:, idx_bad_step]
-                self.model.b_var = b_var_new
+                    ll_proposal = ll_current[idx_update]
+                    idx_bad_step = np.array([], dtype=np.int32)
                 # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
                 ll_new = ll_current.copy()
                 ll_new[idx_update] = ll_proposal
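The wrapped block follows a trial-update pattern that the next hunk repeats for the location model: apply the full step, re-evaluate the per-feature negative log-likelihood, and subtract the step again for features whose loss increased. A simplified sketch of that pattern on a plain numpy parameter matrix (names are illustrative, not the estimator's API):

import numpy as np

def trial_update(var, step, neg_ll_by_feature, ll_current, idx_update):
    # Apply the full step to all features that are still being updated.
    var = var + step
    # Re-evaluate the negative log-likelihood of the updated features.
    ll_proposal = neg_ll_by_feature(var, idx_update)
    # Features whose loss got worse have the step removed again.
    idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
    var[:, idx_bad_step] = var[:, idx_bad_step] - step[:, idx_bad_step]
    return var, ll_proposal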
@@ -126,18 +135,22 @@ def train(
                 # IWLS step for location model:
                 # Compute update.
                 idx_update = self.model.idx_not_converged
-                a_step = self.iwls_step(idx_update=idx_update)
-                # Perform trial update.
-                self.model.a_var = self.model.a_var + a_step
-                # Reverse update by feature if update leads to worse loss:
-                ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
-                idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
-                if isinstance(self.model.b_var, dask.array.core.Array):
-                    a_var_new = self.model.a_var.compute()
+                if self._train_loc:
+                    a_step = self.iwls_step(idx_update=idx_update)
+                    # Perform trial update.
+                    self.model.a_var = self.model.a_var + a_step
+                    # Reverse update by feature if update leads to worse loss:
+                    ll_proposal = - self.model.ll_byfeature_j(j=idx_update).compute()
+                    idx_bad_step = idx_update[np.where(ll_proposal > ll_current[idx_update])[0]]
+                    if isinstance(self.model.b_var, dask.array.core.Array):
+                        a_var_new = self.model.a_var.compute()
+                    else:
+                        a_var_new = self.model.a_var.copy()
+                    a_var_new[:, idx_bad_step] = a_var_new[:, idx_bad_step] - a_step[:, idx_bad_step]
+                    self.model.a_var = a_var_new
                 else:
-                    a_var_new = self.model.a_var.copy()
-                a_var_new[:, idx_bad_step] = a_var_new[:, idx_bad_step] - a_step[:, idx_bad_step]
-                self.model.a_var = a_var_new
+                    ll_proposal = ll_current[idx_update]
+                    idx_bad_step = np.array([], dtype=np.int32)
                 # Update likelihood vector with updated genes based on already evaluated proposal likelihood.
                 ll_new = ll_current.copy()
                 ll_new[idx_update] = ll_proposal
@@ -273,10 +286,16 @@ def iwls_step(
                 invertible = np.where(dask.array.map_blocks(
                     get_cond_number, a, chunks=a.shape
                 ).squeeze().compute() < 1 / sys.float_info.epsilon)[0]
-                delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
-                    np.linalg.solve, a[invertible], b[invertible, :, None],
-                    chunks=b[invertible, :, None].shape
-                ).squeeze().T.compute()
+                if len(idx_update[invertible]) > 1:
+                    delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
+                        np.linalg.solve, a[invertible], b[invertible, :, None],
+                        chunks=b[invertible, :, None].shape
+                    ).squeeze().T.compute()
+                elif len(idx_update[invertible]) == 1:
+                    delta_theta[:, idx_update[invertible]] = np.expand_dims(
+                        np.linalg.solve(a[invertible[0]], b[invertible[0]]).compute(),
+                        axis=-1
+                    )
             else:
                 if np.linalg.cond(a.compute(), p=None) < 1 / sys.float_info.epsilon:
                     delta_theta[:, idx_update] = np.expand_dims(
@@ -290,7 +309,7 @@ def iwls_step(
                 invertible = np.where(np.linalg.cond(a, p=None) < 1 / sys.float_info.epsilon)[0]
                 delta_theta[:, idx_update[invertible]] = np.linalg.solve(a[invertible], b[invertible]).T
                 if invertible.shape[0] < len(idx_update):
-                    print("caught %i linalg singular matrix errors" % (len(idx_update) - invertible.shape[0]))
+                    sys.stdout.write("caught %i linalg singular matrix errors\n" % (len(idx_update) - invertible.shape[0]))
                 # Via np.linalg.lsts:
                 #delta_theta[:, idx_update] = np.concatenate([
                 #    np.expand_dims(np.linalg.lstsq(a[i, :, :], b[i, :])[0], axis=-1)
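Both the dask and the in-memory branches above guard np.linalg.solve with a condition-number test so that near-singular systems simply keep a zero update instead of raising. A compact in-memory sketch of that guard on stacked coefficient blocks (shapes are assumed for illustration, not taken from the diff):

import sys
import numpy as np

def guarded_solve(a, b):
    # a: (n_features, n_coef, n_coef) blocks, b: (n_features, n_coef) right-hand sides.
    delta_theta = np.zeros((a.shape[1], a.shape[0]))
    # Only solve the blocks whose condition number is safely below 1/eps.
    invertible = np.where(np.linalg.cond(a, p=None) < 1 / sys.float_info.epsilon)[0]
    delta_theta[:, invertible] = np.linalg.solve(a[invertible], b[invertible]).T
    if invertible.shape[0] < a.shape[0]:
        sys.stdout.write("caught %i linalg singular matrix errors\n" % (a.shape[0] - invertible.shape[0]))
    return delta_theta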
@@ -537,14 +556,3 @@ def get_model_container(
             input_data
     ):
         pass
-
-    @abc.abstractmethod
-    def init_par(
-            self,
-            input_data,
-            init_a,
-            init_b,
-            init_model
-    ):
-        pass
-