
Commit 94dae69

added a_only and b_only gene-wise train_ops
1 parent 4498ef9 commit 94dae69

File tree

2 files changed, +97 -7 lines


batchglm/train/tf/nb_glm/estimator.py

Lines changed: 96 additions & 6 deletions
@@ -332,7 +332,7 @@ def __init__(
                 learning_rate=learning_rate,
                 global_step=global_step,
                 apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
-                name="batch_trainers"
+                name="batch_trainers_bygene"
             )
             batch_trainers_a_only = train_utils.MultiTrainer(
                 gradients=[
@@ -349,6 +349,27 @@ def __init__(
                 apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
                 name="batch_trainers_a_only"
             )
+            batch_trainers_a_only_bygene = train_utils.MultiTrainer(
+                gradients=[
+                    (
+                        tf.concat([
+                            tf.concat([
+                                tf.gradients(batch_model.norm_neg_log_likelihood,
+                                             model_vars.a_by_gene[i])[0]
+                                if not model_vars.converged[i]
+                                else tf.zeros([model_vars.a.shape[0], 1])
+                                for i in range(model_vars.a.shape[1])
+                            ], axis=1),
+                            tf.zeros_like(model_vars.b)
+                        ], axis=0),
+                        model_vars.params
+                    ),
+                ],
+                learning_rate=learning_rate,
+                global_step=global_step,
+                apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
+                name="batch_trainers_a_only_bygene"
+            )
             batch_trainers_b_only = train_utils.MultiTrainer(
                 gradients=[
                     (
@@ -364,6 +385,27 @@ def __init__(
                 apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
                 name="batch_trainers_b_only"
             )
+            batch_trainers_b_only_bygene = train_utils.MultiTrainer(
+                gradients=[
+                    (
+                        tf.concat([
+                            tf.zeros_like(model_vars.a),
+                            tf.concat([
+                                tf.gradients(batch_model.norm_neg_log_likelihood,
+                                             model_vars.b_by_gene[i])[0]
+                                if not model_vars.converged[i]
+                                else tf.zeros([model_vars.b.shape[0], 1])
+                                for i in range(model_vars.b.shape[1])
+                            ], axis=1)
+                        ], axis=0),
+                        model_vars.params
+                    ),
+                ],
+                learning_rate=learning_rate,
+                global_step=global_step,
+                apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
+                name="batch_trainers_b_only_bygene"
+            )
 
             with tf.name_scope("batch_gradient"):
                 batch_gradient = batch_trainers.plain_gradient_by_variable(model_vars.params)
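The *_bygene trainers added above assemble the gradient one gene (feature) at a time and substitute an all-zero column for every gene whose converged flag is already set, so an optimizer step only moves the parameters of genes that are still training. Below is a minimal, self-contained sketch of that masking pattern, assuming TensorFlow 1.x; loss, a_by_gene and converged are illustrative placeholders, not batchglm API:

    # Minimal sketch of the per-gene gradient masking used by the *_bygene trainers.
    # Placeholder names; not the batchglm implementation.
    import numpy as np
    import tensorflow as tf

    n_coefs, n_genes = 3, 4
    converged = np.array([False, True, False, False])       # per-gene convergence flags

    # one (n_coefs, 1) variable per gene, analogous to model_vars.a_by_gene[i]
    a_by_gene = [tf.Variable(tf.random_normal([n_coefs, 1])) for _ in range(n_genes)]
    a = tf.concat(a_by_gene, axis=1)                         # full (n_coefs, n_genes) matrix
    loss = tf.reduce_sum(tf.square(a))                       # stand-in for norm_neg_log_likelihood

    # per-gene gradients; converged genes get an all-zero column,
    # so a gradient step leaves their parameters untouched
    grad_a = tf.concat([
        tf.gradients(loss, a_by_gene[i])[0]
        if not converged[i]
        else tf.zeros([n_coefs, 1])
        for i in range(n_genes)
    ], axis=1)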
@@ -396,7 +438,7 @@ def __init__(
                 learning_rate=learning_rate,
                 global_step=global_step,
                 apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
-                name="full_data_trainers"
+                name="full_data_trainers_bygene"
             )
             full_data_trainers_a_only = train_utils.MultiTrainer(
                 gradients=[
@@ -413,6 +455,27 @@ def __init__(
                 apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
                 name="full_data_trainers_a_only"
             )
+            full_data_trainers_a_only_bygene = train_utils.MultiTrainer(
+                gradients=[
+                    (
+                        tf.concat([
+                            tf.concat([
+                                tf.gradients(full_data_model.norm_neg_log_likelihood,
+                                             model_vars.a_by_gene[i])[0]
+                                if not model_vars.converged[i]
+                                else tf.zeros([model_vars.a.shape[0], 1])
+                                for i in range(model_vars.a.shape[1])
+                            ], axis=1),
+                            tf.zeros_like(model_vars.b)
+                        ], axis=0),
+                        model_vars.params
+                    ),
+                ],
+                learning_rate=learning_rate,
+                global_step=global_step,
+                apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
+                name="full_data_trainers_a_only_bygene"
+            )
             full_data_trainers_b_only = train_utils.MultiTrainer(
                 gradients=[
                     (
@@ -428,6 +491,28 @@ def __init__(
                 apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
                 name="full_data_trainers_b_only"
             )
+            full_data_trainers_b_only_bygene = train_utils.MultiTrainer(
+                gradients=[
+                    (
+                        tf.concat([
+                            tf.zeros_like(model_vars.a),
+                            tf.concat([
+                                tf.gradients(full_data_model.norm_neg_log_likelihood,
+                                             model_vars.b_by_gene[i])[0]
+                                if not model_vars.converged[i]
+                                else tf.zeros([model_vars.b.shape[0], 1])
+                                for i in range(model_vars.b.shape[1])
+                            ], axis=1)
+                        ], axis=0),
+                        model_vars.params
+                    ),
+                ],
+                learning_rate=learning_rate,
+                global_step=global_step,
+                apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
+                name="full_data_trainers_b_only_bygene"
+            )
+
             with tf.name_scope("full_gradient"):
                 # use same gradient as the optimizers
                 full_gradient = full_data_trainers.plain_gradient_by_variable(model_vars.params)
@@ -595,12 +680,17 @@ def __init__(
         self.batch_trainers = batch_trainers
         self.batch_trainers_bygene = batch_trainers_bygene
         self.batch_trainers_a_only = batch_trainers_a_only
+        self.batch_trainers_a_only_bygene = batch_trainers_a_only_bygene
         self.batch_trainers_b_only = batch_trainers_b_only
+        self.batch_trainers_b_only_bygene = batch_trainers_b_only_bygene
 
         self.full_data_trainers = full_data_trainers
         self.full_data_trainers_bygene = full_data_trainers_bygene
         self.full_data_trainers_a_only = full_data_trainers_a_only
+        self.full_data_trainers_a_only_bygene = full_data_trainers_a_only_bygene
         self.full_data_trainers_b_only = full_data_trainers_b_only
+        self.full_data_trainers_b_only_bygene = full_data_trainers_b_only_bygene
+
         self.global_step = global_step
 
         self.gradient = batch_gradient
@@ -1185,13 +1275,13 @@ def train(self, *args,
                         train_op = self.model.batch_trainers.train_op_by_name(optim_algo)
                 else:
                     if convergence_criteria == "all_converged":
-                        assert False
+                        train_op = self.model.batch_trainers_a_only_bygene.train_op_by_name(optim_algo)
                     else:
                         train_op = self.model.batch_trainers_a_only.train_op_by_name(optim_algo)
             else:
                 if train_r:
                     if convergence_criteria == "all_converged":
-                        assert False
+                        train_op = self.model.batch_trainers_b_only_bygene.train_op_by_name(optim_algo)
                     else:
                         train_op = self.model.batch_trainers_b_only.train_op_by_name(optim_algo)
                 else:
@@ -1212,13 +1302,13 @@ def train(self, *args,
                         train_op = self.model.full_data_trainers.train_op_by_name(optim_algo)
                 else:
                     if convergence_criteria == "all_converged":
-                        assert False
+                        train_op = self.model.full_data_trainers_a_only_bygene.train_op_by_name(optim_algo)
                     else:
                         train_op = self.model.full_data_trainers_a_only.train_op_by_name(optim_algo)
             else:
                 if train_r:
                     if convergence_criteria == "all_converged":
-                        assert False
+                        train_op = self.model.full_data_trainers_b_only_bygene.train_op_by_name(optim_algo)
                     else:
                         train_op = self.model.full_data_trainers_b_only.train_op_by_name(optim_algo)
                 else:
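
The train() changes above make convergence_criteria == "all_converged" dispatch to the new per-gene train ops instead of failing on assert False when only the location ("a") or only the scale ("b") part of the model is trained. A hedged usage sketch; apart from convergence_criteria, optim_algo and train_r, the argument names and values below are assumptions rather than confirmed batchglm API:

    # Hypothetical call; `train_mu` and the optimizer name are assumptions.
    estimator.train(
        convergence_criteria="all_converged",  # routes to the *_bygene trainers
        train_mu=True,                         # location ("a") model only -> a_only_bygene path
        train_r=False,
        optim_algo="gradient_descent",
    )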

batchglm/unit_test/test_nb_glm_featureconvergence.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ class NB_GLM_Test(unittest.TestCase):
     _estims: List[Estimator]
 
     def setUp(self):
-        self.sim = Simulator(num_observations=1000, num_features=5)
+        self.sim = Simulator(num_observations=1000, num_features=7)
         self.sim.generate()
         self._estims = []
 
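The simulator in the feature-convergence test now generates 7 features instead of 5, giving the per-gene convergence bookkeeping a few more genes to track. As an aside on what the masked gradients achieve, here is a tiny NumPy illustration (not batchglm code) of why a zeroed gradient column freezes a converged gene:

    # Illustrative only: a masked gradient column leaves converged genes unchanged.
    import numpy as np

    params = np.arange(8, dtype=float).reshape(2, 4)    # (coefficients, genes)
    grad = np.ones_like(params)
    converged = np.array([False, True, False, True])
    grad[:, converged] = 0.0                            # zero the columns of converged genes

    updated = params - 0.1 * grad
    assert np.allclose(updated[:, converged], params[:, converged])   # frozen
    assert np.all(updated[:, ~converged] != params[:, ~converged])    # still moving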