Skip to content

Commit 64ad498

Browse files
fixed bug with computing condition number in dask in chunks of genes
1 parent b0c2698 commit 64ad498

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

batchglm/train/numpy/base_glm/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,10 @@ def iwls_step(
269269
# Have to use a workaround to solve problems in parallel in dask here. This workaround does
270270
# not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
271271
if a.shape[0] != 1:
272+
get_cond_number = lambda x: np.expand_dims(np.expand_dims(np.linalg.cond(x, p=None), axis=-1), axis=-1)
272273
invertible = np.where(dask.array.map_blocks(
273-
np.linalg.cond, a, chunks=a.shape
274-
).compute() < 1 / sys.float_info.epsilon)[0]
274+
get_cond_number, a, chunks=a.shape
275+
).squeeze().compute() < 1 / sys.float_info.epsilon)[0]
275276
delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
276277
np.linalg.solve, a[invertible], b[invertible, :, None],
277278
chunks=b[invertible, :, None].shape

batchglm/unit_test/test_acc_glm_all_numpy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ def __init__(
3636
design_scale_names=simulator.input_data.design_scale_names,
3737
constraints_loc=simulator.input_data.constraints_loc,
3838
constraints_scale=simulator.input_data.constraints_scale,
39-
size_factors=simulator.input_data.size_factors
39+
size_factors=simulator.input_data.size_factors,
40+
chunk_size_cells=int(1e9),
41+
chunk_size_genes=2
4042
)
4143
else:
4244
input_data = InputDataGLM(
@@ -47,7 +49,9 @@ def __init__(
4749
design_scale_names=simulator.input_data.design_scale_names,
4850
constraints_loc=simulator.input_data.constraints_loc,
4951
constraints_scale=simulator.input_data.constraints_scale,
50-
size_factors=simulator.input_data.size_factors
52+
size_factors=simulator.input_data.size_factors,
53+
chunk_size_cells=int(1e9),
54+
chunk_size_genes=2
5155
)
5256

5357
self.estimator = Estimator(

0 commit comments

Comments
 (0)