Skip to content

Commit dd17940

Browse files
added feature batch size variables
1 parent f9e5d80 commit dd17940

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(
149149
num_design_scale_params,
150150
graph: tf.Graph = None,
151151
batch_size=500,
152+
feature_batch_size=None,
152153
init_a=None,
153154
init_b=None,
154155
constraints_loc=None,
@@ -162,6 +163,7 @@ def __init__(
162163
self.num_design_loc_params = num_design_loc_params
163164
self.num_design_scale_params = num_design_scale_params
164165
self.batch_size = batch_size
166+
self.feature_batch_size = feature_batch_size
165167

166168
# initial graph elements
167169
with self.graph.as_default():
@@ -861,7 +863,7 @@ def __init__(
861863
init_b = init_scale
862864

863865
# ### prepare fetch_fn:
864-
def fetch_fn(idx):
866+
def fetch_fn(idx_obs, idx_genes=None):
865867
r"""
866868
Documentation of tensorflow coding style in this function:
867869
tf.py_func defines a python function (the getters of the InputData object slots)
@@ -870,8 +872,13 @@ def fetch_fn(idx):
870872
as explained below.
871873
"""
872874
# Catch dimension collapse error if idx is only one element long, ie. 0D:
873-
if len(idx.shape) == 0:
874-
idx = tf.expand_dims(idx, axis=0)
875+
if len(idx_obs.shape) == 0:
876+
idx_obs = tf.expand_dims(idx_obs, axis=0)
877+
if idx_genes is None:
878+
idx_genes = ...
879+
else:
880+
if len(idx_genes.shape) == 0:
881+
idx_genes = tf.expand_dims(idx_genes, axis=0)
875882

876883
X_tensor = tf.py_func(
877884
func=input_data.fetch_X,

0 commit comments

Comments
 (0)