Skip to content

Commit 0139db1

Browse files
fixed dimension bug when single gene was invertible
1 parent 2fcbab6 commit 0139db1

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

batchglm/train/numpy/base_glm/estimator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,16 @@ def iwls_step(
273273
invertible = np.where(dask.array.map_blocks(
274274
get_cond_number, a, chunks=a.shape
275275
).squeeze().compute() < 1 / sys.float_info.epsilon)[0]
276-
delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
277-
np.linalg.solve, a[invertible], b[invertible, :, None],
278-
chunks=b[invertible, :, None].shape
279-
).squeeze().T.compute()
276+
if len(idx_update[invertible]) > 1:
277+
delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
278+
np.linalg.solve, a[invertible], b[invertible, :, None],
279+
chunks=b[invertible, :, None].shape
280+
).squeeze().T.compute()
281+
elif len(idx_update[invertible]) == 1:
282+
delta_theta[:, idx_update[invertible]] = np.expand_dims(
283+
np.linalg.solve(a[invertible], b[invertible]).compute(),
284+
axis=-1
285+
)
280286
else:
281287
if np.linalg.cond(a.compute(), p=None) < 1 / sys.float_info.epsilon:
282288
delta_theta[:, idx_update] = np.expand_dims(

0 commit comments

Comments
 (0)