We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2fcbab6 commit 0139db1Copy full SHA for 0139db1
batchglm/train/numpy/base_glm/estimator.py
@@ -273,10 +273,16 @@ def iwls_step(
273
invertible = np.where(dask.array.map_blocks(
274
get_cond_number, a, chunks=a.shape
275
).squeeze().compute() < 1 / sys.float_info.epsilon)[0]
276
- delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
277
- np.linalg.solve, a[invertible], b[invertible, :, None],
278
- chunks=b[invertible, :, None].shape
279
- ).squeeze().T.compute()
+ if len(idx_update[invertible]) > 1:
+ delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
+ np.linalg.solve, a[invertible], b[invertible, :, None],
+ chunks=b[invertible, :, None].shape
280
+ ).squeeze().T.compute()
281
+ elif len(idx_update[invertible]) == 1:
282
+ delta_theta[:, idx_update[invertible]] = np.expand_dims(
283
+ np.linalg.solve(a[invertible], b[invertible]).compute(),
284
+ axis=-1
285
+ )
286
else:
287
if np.linalg.cond(a.compute(), p=None) < 1 / sys.float_info.epsilon:
288
delta_theta[:, idx_update] = np.expand_dims(
0 commit comments