Skip to content

Commit b0c2698

Browse files
switched singulrity detecction to condition number of hessian
1 parent 5c59a75 commit b0c2698

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

batchglm/train/numpy/base_glm/estimator.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -265,25 +265,31 @@ def iwls_step(
265265
if isinstance(delta_theta, dask.array.core.Array):
266266
delta_theta = delta_theta.compute()
267267

268-
with np.errstate(all="ignore"):
269-
if isinstance(a, dask.array.core.Array):
270-
# Have to use a workaround to solve problems in parallel in dask here. This workaround does
271-
# not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
272-
if a.shape[0] != 1:
273-
delta_theta[:, idx_update] = dask.array.map_blocks(
274-
np.linalg.solve, a, b[:, :, None], chunks=b[:, :, None].shape
275-
).squeeze().T.compute()
276-
else:
268+
if isinstance(a, dask.array.core.Array):
269+
# Have to use a workaround to solve problems in parallel in dask here. This workaround does
270+
# not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
271+
if a.shape[0] != 1:
272+
invertible = np.where(dask.array.map_blocks(
273+
np.linalg.cond, a, chunks=a.shape
274+
).compute() < 1 / sys.float_info.epsilon)[0]
275+
delta_theta[:, idx_update[invertible]] = dask.array.map_blocks(
276+
np.linalg.solve, a[invertible], b[invertible, :, None],
277+
chunks=b[invertible, :, None].shape
278+
).squeeze().T.compute()
279+
else:
280+
if np.linalg.cond(a.compute(), p=None) < 1 / sys.float_info.epsilon:
277281
delta_theta[:, idx_update] = np.expand_dims(
278282
np.linalg.solve(a[0], b[0]).compute(),
279283
axis=-1
280284
)
281-
else:
282-
delta_theta[:, idx_update] = np.linalg.solve(a, b).T
283-
linalg_errors = np.isnan(delta_theta[0, :])
284-
if np.any(linalg_errors):
285-
print("caught %i linalg errors" % np.sum(linalg_errors))
286-
delta_theta[:, linalg_errors] = 0.
285+
invertible = np.array([0])
286+
else:
287+
invertible = np.array([])
288+
else:
289+
invertible = np.where(np.linalg.cond(a, p=None) < 1 / sys.float_info.epsilon)[0]
290+
delta_theta[:, idx_update[invertible]] = np.linalg.solve(a[invertible], b[invertible]).T
291+
if invertible.shape[0] < len(idx_update):
292+
print("caught %i linalg singular matrix errors" % len(idx_update) - invertible.shape[0])
287293
# Via np.linalg.lsts:
288294
#delta_theta[:, idx_update] = np.concatenate([
289295
# np.expand_dims(np.linalg.lstsq(a[i, :, :], b[i, :])[0], axis=-1)
@@ -516,8 +522,10 @@ def finalize(self):
516522
"""
517523
# Read from numpy-IRLS estimator specific model:
518524
self._hessian = - self.model.fim.compute()
519-
with np.errstate(all="ignore"):
520-
self._fisher_inv = np.linalg.inv(- self._hessian)
525+
fisher_inv = np.zeros_like(self._hessian)
526+
invertible = np.where(np.linalg.cond(self._hessian, p=None) < 1 / sys.float_info.epsilon)[0]
527+
fisher_inv[invertible] = np.linalg.inv(- self._hessian[invertible])
528+
self._fisher_inv = fisher_inv
521529
self._jacobian = np.sum(np.abs(self.model.jac.compute() / self.model.x.shape[0]), axis=1)
522530
self._log_likelihood = self.model.ll_byfeature.compute()
523531
self._loss = np.sum(self._log_likelihood)

0 commit comments

Comments
 (0)