Skip to content

Commit 6692b7b

Browse files
added linalg error propagation as nan which is then separately dealt with as singular matrix fix
1 parent 9a7d92b commit 6692b7b

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

batchglm/train/numpy/base_glm/estimator.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import multiprocessing
55
import numpy as np
6-
import pprint
76
import scipy
87
import scipy.sparse
98
import scipy.optimize
@@ -266,20 +265,25 @@ def iwls_step(
266265
if isinstance(delta_theta, dask.array.core.Array):
267266
delta_theta = delta_theta.compute()
268267

269-
if isinstance(a, dask.array.core.Array):
270-
# Have to use a workaround to solve problems in parallel in dask here. This workaround does
271-
# not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
272-
if a.shape[0] != 1:
273-
delta_theta[:, idx_update] = dask.array.map_blocks(
274-
np.linalg.solve, a, b[:, :, None], chunks=b[:, :, None].shape
275-
).squeeze().T.compute()
268+
with np.errstate(linalg="ignore"):
269+
if isinstance(a, dask.array.core.Array):
270+
# Have to use a workaround to solve problems in parallel in dask here. This workaround does
271+
# not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
272+
if a.shape[0] != 1:
273+
delta_theta[:, idx_update] = dask.array.map_blocks(
274+
np.linalg.solve, a, b[:, :, None], chunks=b[:, :, None].shape
275+
).squeeze().T.compute()
276+
else:
277+
delta_theta[:, idx_update] = np.expand_dims(
278+
np.linalg.solve(a[0], b[0]).compute(),
279+
axis=-1
280+
)
276281
else:
277-
delta_theta[:, idx_update] = np.expand_dims(
278-
np.linalg.solve(a[0], b[0]).compute(),
279-
axis=-1
280-
)
281-
else:
282-
delta_theta[:, idx_update] = np.linalg.solve(a, b).T
282+
delta_theta[:, idx_update] = np.linalg.solve(a, b).T
283+
linalg_errors = np.isnan(delta_theta[:, 0, 0])
284+
if np.any(linalg_errors):
285+
print("caught %i linalg errors" % np.sum(linalg_errors))
286+
delta_theta[linalg_errors, :, :] = 0.
283287
# Via np.linalg.lsts:
284288
#delta_theta[:, idx_update] = np.concatenate([
285289
# np.expand_dims(np.linalg.lstsq(a[i, :, :], b[i, :])[0], axis=-1)
@@ -512,7 +516,8 @@ def finalize(self):
512516
"""
513517
# Read from numpy-IRLS estimator specific model:
514518
self._hessian = - self.model.fim.compute()
515-
self._fisher_inv = np.linalg.inv(- self._hessian)
519+
with np.errstate(linalg="ignore"):
520+
self._fisher_inv = np.linalg.inv(- self._hessian)
516521
self._jacobian = np.sum(np.abs(self.model.jac.compute() / self.model.x.shape[0]), axis=1)
517522
self._log_likelihood = self.model.ll_byfeature.compute()
518523
self._loss = np.sum(self._log_likelihood)

0 commit comments

Comments
 (0)