Skip to content

Commit 5707d0a

Browse files
Merge pull request #99 from theislab/fix_size_factors_init
bugfix: size_factors_init applied expand_dims twice
2 parents 8dcd460 + d182a84 commit 5707d0a

File tree

4 files changed

+9
-29
lines changed

4 files changed

+9
-29
lines changed

batchglm/models/base_glm/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,4 @@ def fetch_design_scale(self, idx):
197197
return self.design_scale[idx, :]
198198

199199
def fetch_size_factors(self, idx):
200-
return self.size_factors[idx]
200+
return self.size_factors[idx, :]

batchglm/train/tf1/glm_beta/estimator.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,6 @@ def init_par(
169169
$$
170170
"""
171171

172-
size_factors_init = input_data.size_factors
173-
174172
if init_model is None:
175173
groupwise_means = None
176174
init_a_str = None
@@ -185,7 +183,7 @@ def init_par(
185183
x=input_data.x,
186184
design_loc=input_data.design_loc,
187185
constraints_loc=input_data.constraints_loc,
188-
size_factors=size_factors_init,
186+
size_factors=input_data.size_factors,
189187
link_fn=lambda mean: np.log(
190188
1/(1/self.np_clip_param(mean, "mean")-1)
191189
)
@@ -221,7 +219,7 @@ def init_par(
221219
x=input_data.x,
222220
design_scale=input_data.design_scale[:, [0]],
223221
constraints=input_data.constraints_scale[[0], :][:, [0]],
224-
size_factors=size_factors_init,
222+
size_factors=input_data.size_factors,
225223
groupwise_means=None,
226224
link_fn=lambda samplesize: np.log(self.np_clip_param(samplesize, "samplesize"))
227225
)
@@ -248,7 +246,7 @@ def init_par(
248246
x=input_data.x,
249247
design_scale=input_data.design_scale,
250248
constraints=input_data.constraints_scale,
251-
size_factors=size_factors_init,
249+
size_factors=input_data.size_factors,
252250
groupwise_means=groupwise_means,
253251
link_fn=lambda samplesize: np.log(self.np_clip_param(samplesize, "samplesize"))
254252
)
@@ -291,4 +289,3 @@ def init_par(
291289
logging.getLogger("batchglm").debug("Using initialization based on input model for dispersion")
292290

293291
return init_a, init_b
294-

batchglm/train/tf1/glm_nb/estimator.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,6 @@ def init_par(
176176
$$
177177
"""
178178

179-
size_factors_init = input_data.size_factors
180-
if size_factors_init is not None:
181-
size_factors_init = np.expand_dims(size_factors_init, axis=1)
182-
size_factors_init = np.broadcast_to(
183-
array=size_factors_init,
184-
shape=[input_data.num_observations, input_data.num_features]
185-
)
186-
187179
if init_model is None:
188180
groupwise_means = None
189181
init_a_str = None
@@ -198,7 +190,7 @@ def init_par(
198190
x=input_data.x,
199191
design_loc=input_data.design_loc,
200192
constraints_loc=input_data.constraints_loc,
201-
size_factors=size_factors_init,
193+
size_factors=input_data.size_factors,
202194
link_fn=lambda mu: np.log(self.np_clip_param(mu, "mu"))
203195
)
204196

@@ -239,7 +231,7 @@ def init_par(
239231
x=input_data.x,
240232
design_scale=input_data.design_scale[:, [0]],
241233
constraints=input_data.constraints_scale[[0], :][:, [0]],
242-
size_factors=size_factors_init,
234+
size_factors=input_data.size_factors,
243235
groupwise_means=None,
244236
link_fn=lambda r: np.log(self.np_clip_param(r, "r"))
245237
)
@@ -267,7 +259,7 @@ def init_par(
267259
x=input_data.x,
268260
design_scale=input_data.design_scale,
269261
constraints=input_data.constraints_scale,
270-
size_factors=size_factors_init,
262+
size_factors=input_data.size_factors,
271263
groupwise_means=groupwise_means,
272264
link_fn=lambda r: np.log(self.np_clip_param(r, "r"))
273265
)

batchglm/train/tf1/glm_norm/estimator.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,6 @@ def init_par(
172172
$$
173173
"""
174174

175-
size_factors_init = input_data.size_factors
176-
if size_factors_init is not None:
177-
size_factors_init = np.expand_dims(size_factors_init, axis=1)
178-
size_factors_init = np.broadcast_to(
179-
array=size_factors_init,
180-
shape=[input_data.num_observations, input_data.num_features]
181-
)
182-
183175
sf_given = False
184176
if input_data.size_factors is not None:
185177
if np.any(np.abs(input_data.size_factors - 1.) > 1e-8):
@@ -268,7 +260,7 @@ def init_par(
268260
x=input_data.x,
269261
design_scale=input_data.design_scale,
270262
constraints=input_data.constraints_scale,
271-
size_factors=size_factors_init,
263+
size_factors=input_data.size_factors,
272264
groupwise_means=groupwise_means,
273265
link_fn=lambda sd: np.log(self.np_clip_param(sd, "sd"))
274266
)
@@ -282,7 +274,7 @@ def init_par(
282274
x=input_data.x,
283275
design_scale=input_data.design_scale[:, [0]],
284276
constraints=input_data.constraints_scale[[0], :][:, [0]],
285-
size_factors=size_factors_init,
277+
size_factors=input_data.size_factors,
286278
groupwise_means=None,
287279
link_fn=lambda sd: np.log(self.np_clip_param(sd, "sd"))
288280
)
@@ -331,4 +323,3 @@ def init_par(
331323
logger.debug("Using initialization based on input model for dispersion")
332324

333325
return init_a, init_b
334-

0 commit comments

Comments (0)