Skip to content

Commit 97c4407

Browse files
further changes to comply with diffxpy api
1 parent e746783 commit 97c4407

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

batchglm/models/base/input.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,14 @@ def feature_isallzero(self):
7272
return self._feature_allzero
7373

7474
def fetch_x_dense(self, idx):
75-
return self.x[idx]
75+
assert isinstance(self.x, np.ndarray), "tried to fetch dense from non ndarray"
76+
77+
return self.x[idx, :]
7678

7779
def fetch_x_sparse(self, idx):
78-
assert isinstance(self.x, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr matrix"
80+
assert isinstance(self.x, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr_matrix"
7981

80-
data = self.x[idx]
82+
data = self.x[idx, :]
8183

8284
data_idx = np.asarray(np.vstack(data.nonzero()).T, np.int64)
8385
data_val = np.asarray(data.data, np.float64)

batchglm/models/base_glm/input.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,10 @@ def num_scale_params(self):
133133
return self.constraints_scale.shape[1]
134134

135135
def fetch_design_loc(self, idx):
136-
return self.design_loc[idx]
136+
return self.design_loc[idx, :]
137137

138138
def fetch_design_scale(self, idx):
139-
return self.design_scale[idx]
139+
return self.design_scale[idx, :]
140140

141141
def fetch_size_factors(self, idx):
142142
return self.size_factors[idx]

batchglm/train/tf/base_glm_all/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def fetch_fn(idx):
168168
num_design_scale_params=input_data.num_design_scale_params,
169169
num_loc_params=input_data.num_loc_params,
170170
num_scale_params=input_data.num_scale_params,
171-
batch_size=batch_size,
171+
batch_size=np.min([batch_size, input_data.x.shape[0]]),
172172
graph=graph,
173173
init_a=init_a,
174174
init_b=init_b,

batchglm/unit_test/test_graph_glm_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
else:
3535
raise ValueError("noise_model not recognized")
3636

37-
batch_size = 100
37+
batch_size = 200
3838
provide_optimizers = {
3939
"gd": False, "adam": False, "adagrad": False, "rmsprop": False,
4040
"nr": False, "nr_tr": False,

0 commit comments

Comments
 (0)