@@ -332,7 +332,7 @@ def __init__(
332332 learning_rate = learning_rate ,
333333 global_step = global_step ,
334334 apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
335- name = "batch_trainers "
335+ name = "batch_trainers_bygene "
336336 )
337337 batch_trainers_a_only = train_utils .MultiTrainer (
338338 gradients = [
@@ -349,6 +349,27 @@ def __init__(
349349 apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
350350 name = "batch_trainers_a_only"
351351 )
352+ batch_trainers_a_only_bygene = train_utils .MultiTrainer (
353+ gradients = [
354+ (
355+ tf .concat ([
356+ tf .concat ([
357+ tf .gradients (batch_model .norm_neg_log_likelihood ,
358+ model_vars .a_by_gene [i ])[0 ]
359+ if not model_vars .converged [i ]
360+ else tf .zeros ([model_vars .a .shape [0 ], 1 ])
361+ for i in range (model_vars .a .shape [1 ])
362+ ], axis = 1 ),
363+ tf .zeros_like (model_vars .b )
364+ ], axis = 0 ),
365+ model_vars .params
366+ ),
367+ ],
368+ learning_rate = learning_rate ,
369+ global_step = global_step ,
370+ apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
371+ name = "batch_trainers_a_only_bygene"
372+ )
352373 batch_trainers_b_only = train_utils .MultiTrainer (
353374 gradients = [
354375 (
@@ -364,6 +385,27 @@ def __init__(
364385 apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
365386 name = "batch_trainers_b_only"
366387 )
388+ batch_trainers_b_only_bygene = train_utils .MultiTrainer (
389+ gradients = [
390+ (
391+ tf .concat ([
392+ tf .zeros_like (model_vars .a ),
393+ tf .concat ([
394+ tf .gradients (batch_model .norm_neg_log_likelihood ,
395+ model_vars .b_by_gene [i ])[0 ]
396+ if not model_vars .converged [i ]
397+ else tf .zeros ([model_vars .b .shape [0 ], 1 ])
398+ for i in range (model_vars .b .shape [1 ])
399+ ], axis = 1 )
400+ ], axis = 0 ),
401+ model_vars .params
402+ ),
403+ ],
404+ learning_rate = learning_rate ,
405+ global_step = global_step ,
406+ apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
407+ name = "batch_trainers_b_only_bygene"
408+ )
367409
368410 with tf .name_scope ("batch_gradient" ):
369411 batch_gradient = batch_trainers .plain_gradient_by_variable (model_vars .params )
@@ -396,7 +438,7 @@ def __init__(
396438 learning_rate = learning_rate ,
397439 global_step = global_step ,
398440 apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
399- name = "full_data_trainers "
441+ name = "full_data_trainers_bygene "
400442 )
401443 full_data_trainers_a_only = train_utils .MultiTrainer (
402444 gradients = [
@@ -413,6 +455,27 @@ def __init__(
413455 apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
414456 name = "full_data_trainers_a_only"
415457 )
458+ full_data_trainers_a_only_bygene = train_utils .MultiTrainer (
459+ gradients = [
460+ (
461+ tf .concat ([
462+ tf .concat ([
463+ tf .gradients (full_data_model .norm_neg_log_likelihood ,
464+ model_vars .a_by_gene [i ])[0 ]
465+ if not model_vars .converged [i ]
466+ else tf .zeros ([model_vars .a .shape [0 ], 1 ])
467+ for i in range (model_vars .a .shape [1 ])
468+ ], axis = 1 ),
469+ tf .zeros_like (model_vars .b )
470+ ], axis = 0 ),
471+ model_vars .params
472+ ),
473+ ],
474+ learning_rate = learning_rate ,
475+ global_step = global_step ,
476+ apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
477+ name = "full_data_trainers_a_only_bygene"
478+ )
416479 full_data_trainers_b_only = train_utils .MultiTrainer (
417480 gradients = [
418481 (
@@ -428,6 +491,28 @@ def __init__(
428491 apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
429492 name = "full_data_trainers_b_only"
430493 )
494+ full_data_trainers_b_only_bygene = train_utils .MultiTrainer (
495+ gradients = [
496+ (
497+ tf .concat ([
498+ tf .zeros_like (model_vars .a ),
499+ tf .concat ([
500+ tf .gradients (full_data_model .norm_neg_log_likelihood ,
501+ model_vars .b_by_gene [i ])[0 ]
502+ if not model_vars .converged [i ]
503+ else tf .zeros ([model_vars .b .shape [0 ], 1 ])
504+ for i in range (model_vars .b .shape [1 ])
505+ ], axis = 1 )
506+ ], axis = 0 ),
507+ model_vars .params
508+ ),
509+ ],
510+ learning_rate = learning_rate ,
511+ global_step = global_step ,
512+ apply_gradients = lambda grad : tf .where (tf .is_nan (grad ), tf .zeros_like (grad ), grad ),
513+ name = "full_data_trainers_b_only_bygene"
514+ )
515+
431516 with tf .name_scope ("full_gradient" ):
432517 # use same gradient as the optimizers
433518 full_gradient = full_data_trainers .plain_gradient_by_variable (model_vars .params )
@@ -595,12 +680,17 @@ def __init__(
595680 self .batch_trainers = batch_trainers
596681 self .batch_trainers_bygene = batch_trainers_bygene
597682 self .batch_trainers_a_only = batch_trainers_a_only
683+ self .batch_trainers_a_only_bygene = batch_trainers_a_only_bygene
598684 self .batch_trainers_b_only = batch_trainers_b_only
685+ self .batch_trainers_b_only_bygene = batch_trainers_b_only_bygene
599686
600687 self .full_data_trainers = full_data_trainers
601688 self .full_data_trainers_bygene = full_data_trainers_bygene
602689 self .full_data_trainers_a_only = full_data_trainers_a_only
690+ self .full_data_trainers_a_only_bygene = full_data_trainers_a_only_bygene
603691 self .full_data_trainers_b_only = full_data_trainers_b_only
692+ self .full_data_trainers_b_only_bygene = full_data_trainers_b_only_bygene
693+
604694 self .global_step = global_step
605695
606696 self .gradient = batch_gradient
@@ -1185,13 +1275,13 @@ def train(self, *args,
11851275 train_op = self .model .batch_trainers .train_op_by_name (optim_algo )
11861276 else :
11871277 if convergence_criteria == "all_converged" :
1188- assert False
1278+ train_op = self . model . batch_trainers_a_only_bygene . train_op_by_name ( optim_algo )
11891279 else :
11901280 train_op = self .model .batch_trainers_a_only .train_op_by_name (optim_algo )
11911281 else :
11921282 if train_r :
11931283 if convergence_criteria == "all_converged" :
1194- assert False
1284+ train_op = self . model . batch_trainers_b_only_bygene . train_op_by_name ( optim_algo )
11951285 else :
11961286 train_op = self .model .batch_trainers_b_only .train_op_by_name (optim_algo )
11971287 else :
@@ -1212,13 +1302,13 @@ def train(self, *args,
12121302 train_op = self .model .full_data_trainers .train_op_by_name (optim_algo )
12131303 else :
12141304 if convergence_criteria == "all_converged" :
1215- assert False
1305+ train_op = self . model . full_data_trainers_a_only_bygene . train_op_by_name ( optim_algo )
12161306 else :
12171307 train_op = self .model .full_data_trainers_a_only .train_op_by_name (optim_algo )
12181308 else :
12191309 if train_r :
12201310 if convergence_criteria == "all_converged" :
1221- assert False
1311+ train_op = self . model . full_data_trainers_b_only_bygene . train_op_by_name ( optim_algo )
12221312 else :
12231313 train_op = self .model .full_data_trainers_b_only .train_op_by_name (optim_algo )
12241314 else :
0 commit comments