@@ -149,6 +149,7 @@ def __init__(
149149 num_design_scale_params ,
150150 graph : tf .Graph = None ,
151151 batch_size = 500 ,
152+ feature_batch_size = None ,
152153 init_a = None ,
153154 init_b = None ,
154155 constraints_loc = None ,
@@ -162,6 +163,7 @@ def __init__(
162163 self .num_design_loc_params = num_design_loc_params
163164 self .num_design_scale_params = num_design_scale_params
164165 self .batch_size = batch_size
166+ self .feature_batch_size = feature_batch_size
165167
166168 # initial graph elements
167169 with self .graph .as_default ():
@@ -861,7 +863,7 @@ def __init__(
861863 init_b = init_scale
862864
863865 # ### prepare fetch_fn:
864- def fetch_fn (idx ):
866+ def fetch_fn (idx_obs , idx_genes = None ):
865867 r"""
866868 Documentation of tensorflow coding style in this function:
867869 tf.py_func defines a python function (the getters of the InputData object slots)
@@ -870,8 +872,13 @@ def fetch_fn(idx):
870872 as explained below.
871873 """
872874 # Catch dimension collapse error if idx is only one element long, ie. 0D:
873- if len (idx .shape ) == 0 :
874- idx = tf .expand_dims (idx , axis = 0 )
875+ if len (idx_obs .shape ) == 0 :
876+ idx_obs = tf .expand_dims (idx_obs , axis = 0 )
877+ if idx_genes is None :
878+ idx_genes = ...
879+ else :
880+ if len (idx_genes .shape ) == 0 :
881+ idx_genes = tf .expand_dims (idx_genes , axis = 0 )
875882
876883 X_tensor = tf .py_func (
877884 func = input_data .fetch_X ,
0 commit comments