Skip to content

Commit 3277fe0

Browse files
Merge pull request #54 from theislab/feature_batching
Feature-wise fitting termination
2 parents f9e5d80 + 94dae69 commit 3277fe0

File tree

5 files changed

+430
-23
lines changed

5 files changed

+430
-23
lines changed

batchglm/train/tf/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,26 @@ def train(self, *args,
287287
)
288288

289289
tf.logging.info("Step: %d\tloss: %f", train_step, global_loss)
290+
elif convergence_criteria == "all_converged":
291+
theta_current = self.session.run(self.model.model_vars.params)
292+
while np.any(self.model.model_vars.converged == False):
293+
theta_prev = theta_current
294+
train_step, global_loss, _ = self.session.run(
295+
(self.model.global_step, loss, train_op),
296+
feed_dict=feed_dict
297+
)
298+
theta_current = self.session.run(self.model.model_vars.params)
299+
theta_delta = np.abs(theta_prev - theta_current)
300+
self.model.model_vars.converged = np.logical_or( # Only update non-converged.
301+
self.model.model_vars.converged,
302+
np.max(theta_delta, axis=0) < stopping_criteria
303+
)
304+
tf.logging.info(
305+
"Step: %d\tloss: %f\t models converged %i",
306+
train_step,
307+
global_loss,
308+
np.sum(self.model.model_vars.converged).astype("int32")
309+
)
290310
else:
291311
self._train_to_convergence(
292312
loss=loss,

batchglm/train/tf/nb_glm/base.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ class ModelVars:
225225
a_var: tf.Variable
226226
b_var: tf.Variable
227227
params: tf.Variable
228+
converged: np.ndarray
229+
228230
""" Build tf.Variables to be optimized and their constraints.
229231
230232
a_var and b_var slices of the tf.Variable params which contains
@@ -309,8 +311,13 @@ def __init__(
309311
axis=0
310312
), name="params")
311313

312-
a_var = params[0:init_a.shape[0]]
313-
b_var = params[init_a.shape[0]:]
314+
params_by_gene = [tf.expand_dims(params[:, i], axis=-1) for i in range(params.shape[1])]
315+
a_by_gene = [x[0:init_a.shape[0],:] for x in params_by_gene]
316+
b_by_gene = [x[init_a.shape[0]:, :] for x in params_by_gene]
317+
a_var = tf.concat(a_by_gene, axis=1)
318+
b_var = tf.concat(b_by_gene, axis=1)
319+
#a_var = params[0:init_a.shape[0]]
320+
#b_var = params[init_a.shape[0]:]
314321

315322
# Define first layer of computation graph on identifiable variables
316323
# to yield dependent set of parameters of model for each location
@@ -334,3 +341,10 @@ def __init__(
334341
self.a_var = a_var
335342
self.b_var = b_var
336343
self.params = params
344+
# Properties to follow gene-wise convergence.
345+
self.params_by_gene = params_by_gene
346+
self.a_by_gene = a_by_gene
347+
self.b_by_gene = b_by_gene
348+
self.converged = np.repeat(a=False, repeats=self.params.shape[1]) # Initialise to non-converged.
349+
self.n_features = self.params.shape[1]
350+

0 commit comments

Comments
 (0)