@@ -149,7 +149,6 @@ def __init__(
             num_design_scale_params,
             graph: tf.Graph = None,
             batch_size=500,
-            feature_batch_size=None,
             init_a=None,
             init_b=None,
             constraints_loc=None,
@@ -163,7 +162,6 @@ def __init__(
         self.num_design_loc_params = num_design_loc_params
         self.num_design_scale_params = num_design_scale_params
         self.batch_size = batch_size
-        self.feature_batch_size = feature_batch_size

         # initial graph elements
         with self.graph.as_default():
@@ -305,11 +303,25 @@ def __init__(
             with tf.name_scope("training"):
                 global_step = tf.train.get_or_create_global_step()

-                # set up trainers for different selections of variables to train
-                # set up multiple optimization algorithms for each trainer
+                # Set up trainers for different selections of variables to train.
+                # Set up multiple optimization algorithms for each trainer.
+                # Note that params is a tf.Variable, whereas a and b are tensors
+                # (slices of that variable), so their updates are implemented differently.
                 batch_trainers = train_utils.MultiTrainer(
-                    loss=batch_model.norm_neg_log_likelihood,
-                    variables=[model_vars.params],
+                    # loss=batch_model.norm_neg_log_likelihood,  # add only selected features here TODO
+                    # variables=[model_vars.params],  # tf.gather(model_vars.params, indices=np.where(model_vars.converged == False)[0], axis=1)],
+                    gradients=[
+                        (
+                            tf.concat([
+                                tf.gradients(batch_model.norm_neg_log_likelihood,
+                                             model_vars.params_by_gene[i])[0]
+                                if i in np.where(model_vars.converged == False)[0]
+                                else tf.zeros([model_vars.params.shape[0], 1])
+                                for i in range(model_vars.params.shape[1])
+                            ], axis=1),
+                            model_vars.params
+                        ),
+                    ],
                     learning_rate=learning_rate,
                     global_step=global_step,
                     apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
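The new gradients argument above replaces the plain loss/variables pair: for every gene (column of params), the gradient is either computed against that gene's parameter slice or substituted by a zero column when model_vars.converged marks the gene as finished, so converged genes stop being updated. Below is a minimal, self-contained TF 1.x sketch of that masking idea; it uses one full gradient and a broadcast mask instead of the per-column tf.gradients loop in the diff, and all names (n_params, n_genes, converged) are illustrative rather than batchglm's API.

# Sketch only: zero the gradient columns of converged features so that a
# gradient step leaves those columns of params unchanged.
import numpy as np
import tensorflow as tf

n_params, n_genes = 4, 6
params = tf.Variable(tf.random_normal([n_params, n_genes]))
loss = tf.reduce_sum(tf.square(params))                # stand-in objective

converged = np.array([False, True, False, True, False, False])
mask = tf.constant((~converged).astype(np.float32))    # 1.0 where still training

grad = tf.gradients(loss, params)[0]                   # shape [n_params, n_genes]
masked_grad = grad * mask                              # broadcasts over the parameter axis
train_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients([(masked_grad, params)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(params)
    sess.run(train_op)
    after = sess.run(params)
    assert np.allclose(before[:, converged], after[:, converged])  # converged genes untouched

The diff itself goes one step further: it calls tf.gradients separately for each non-converged column via model_vars.params_by_gene and concatenates the results, so no gradient is computed at all for converged slices.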
@@ -354,8 +366,20 @@ def __init__(
                 # [tf.reduce_sum(tf.abs(grad), axis=0) for (grad, var) in batch_trainers.gradient])

                 full_data_trainers = train_utils.MultiTrainer(
-                    loss=full_data_model.norm_neg_log_likelihood,
-                    variables=[model_vars.params],
+                    # loss=full_data_model.norm_neg_log_likelihood,
+                    # variables=[tf.gather(model_vars.params, indices=np.where(model_vars.converged == False)[0], axis=1)],
+                    gradients=[
+                        (
+                            tf.concat([
+                                tf.gradients(full_data_model.norm_neg_log_likelihood,
+                                             model_vars.params_by_gene[i])[0]
+                                if i in np.where(model_vars.converged == False)[0]
+                                else tf.zeros([model_vars.params.shape[0], 1])
+                                for i in range(model_vars.params.shape[1])
+                            ], axis=1),
+                            model_vars.params
+                        ),
+                    ],
                     learning_rate=learning_rate,
                     global_step=global_step,
                     apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
@@ -863,7 +887,7 @@ def __init__(
         init_b = init_scale

         # ### prepare fetch_fn:
-        def fetch_fn(idx_obs, idx_genes=None):
+        def fetch_fn(idx):
             r"""
             Documentation of tensorflow coding style in this function:
             tf.py_func defines a python function (the getters of the InputData object slots)
@@ -872,13 +896,8 @@ def fetch_fn(idx_obs, idx_genes=None):
             as explained below.
             """
             # Catch dimension collapse error if idx is only one element long, i.e. 0D:
-            if len(idx_obs.shape) == 0:
-                idx_obs = tf.expand_dims(idx_obs, axis=0)
-            if idx_genes is None:
-                idx_genes = ...
-            else:
-                if len(idx_genes.shape) == 0:
-                    idx_genes = tf.expand_dims(idx_genes, axis=0)
+            if len(idx.shape) == 0:
+                idx = tf.expand_dims(idx, axis=0)

             X_tensor = tf.py_func(
                 func=input_data.fetch_X,
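The fetch_fn docstring above refers to the tf.py_func pattern: a plain Python getter (here input_data.fetch_X) is wrapped as a graph operation, and because py_func cannot infer static shapes, the output shape has to be restored afterwards with set_shape. A minimal, self-contained sketch of that pattern follows; fetch_X, the array sizes, and the placeholder are illustrative stand-ins, not the project's InputData API.

# Sketch only: wrap a NumPy-backed getter as a TF 1.x graph op via tf.py_func.
import numpy as np
import tensorflow as tf

data = np.random.rand(100, 10).astype(np.float32)  # (observations, features)

def fetch_X(idx):
    # plain Python/NumPy code, executed at session run time
    return data[idx]

idx = tf.placeholder(tf.int64, shape=[None])
X_tensor = tf.py_func(func=fetch_X, inp=[idx], Tout=tf.float32, stateful=False)
# py_func drops static shape information, so it is set manually:
X_tensor.set_shape([None, data.shape[1]])

with tf.Session() as sess:
    X_batch = sess.run(X_tensor, feed_dict={idx: np.arange(5, dtype=np.int64)})
    print(X_batch.shape)  # (5, 10)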