Skip to content

Commit 5c59a75

Browse files
fixed issues with previous linalg fix
1 parent 6692b7b commit 5c59a75

File tree

4 files changed

+53
-5
lines changed

4 files changed

+53
-5
lines changed

batchglm/models/base_glm/simulator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,16 @@ def constraints_loc(self):
171171
@property
172172
def constraints_scale(self):
173173
return np.identity(n=self.b_var.shape[0])
174+
175+
def np_clip_param(
176+
self,
177+
param,
178+
name
179+
):
180+
# TODO: inherit this from somewhere?
181+
bounds_min, bounds_max = self.param_bounds(param.dtype)
182+
return np.clip(
183+
param,
184+
bounds_min[name],
185+
bounds_max[name]
186+
)

batchglm/models/glm_nb/external.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@
55
from batchglm.models.base_glm import closedform_glm_mean, closedform_glm_scale
66

77
import batchglm.data as data_utils
8-
from batchglm.utils.linalg import groupwise_solve_lm
8+
from batchglm.utils.linalg import groupwise_solve_lm
9+
10+
from batchglm import pkg_constants

batchglm/models/glm_nb/simulator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .model import Model
44
from .external import _SimulatorGLM, InputDataGLM
5+
from .external import pkg_constants
56

67

78
class Simulator(_SimulatorGLM, Model):
@@ -58,3 +59,35 @@ def generate_data(self):
5859
design_scale_names=None
5960
)
6061

62+
def param_bounds(
63+
self,
64+
dtype
65+
):
66+
# TODO: inherit this from somewhere?
67+
dtype = np.dtype(dtype)
68+
dmin = np.finfo(dtype).min
69+
dmax = np.finfo(dtype).max
70+
dtype = dtype.type
71+
72+
sf = dtype(pkg_constants.ACCURACY_MARGIN_RELATIVE_TO_LIMIT)
73+
bounds_min = {
74+
"a_var": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
75+
"b_var": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
76+
"eta_loc": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
77+
"eta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf,
78+
"loc": np.nextafter(0, np.inf, dtype=dtype),
79+
"scale": np.nextafter(0, np.inf, dtype=dtype),
80+
"likelihood": dtype(0),
81+
"ll": np.log(np.nextafter(0, np.inf, dtype=dtype)),
82+
}
83+
bounds_max = {
84+
"a_var": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
85+
"b_var": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
86+
"eta_loc": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
87+
"eta_scale": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf,
88+
"loc": np.nextafter(dmax, -np.inf, dtype=dtype) / sf,
89+
"scale": np.nextafter(dmax, -np.inf, dtype=dtype) / sf,
90+
"likelihood": dtype(1),
91+
"ll": dtype(0),
92+
}
93+
return bounds_min, bounds_max

batchglm/train/numpy/base_glm/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def iwls_step(
265265
if isinstance(delta_theta, dask.array.core.Array):
266266
delta_theta = delta_theta.compute()
267267

268-
with np.errstate(linalg="ignore"):
268+
with np.errstate(all="ignore"):
269269
if isinstance(a, dask.array.core.Array):
270270
# Have to use a workaround to solve problems in parallel in dask here. This workaround does
271271
# not work if there is only a single problem, ie. if the first dimension of a and b has length 1.
@@ -280,10 +280,10 @@ def iwls_step(
280280
)
281281
else:
282282
delta_theta[:, idx_update] = np.linalg.solve(a, b).T
283-
linalg_errors = np.isnan(delta_theta[:, 0, 0])
283+
linalg_errors = np.isnan(delta_theta[0, :])
284284
if np.any(linalg_errors):
285285
print("caught %i linalg errors" % np.sum(linalg_errors))
286-
delta_theta[linalg_errors, :, :] = 0.
286+
delta_theta[:, linalg_errors] = 0.
287287
# Via np.linalg.lsts:
288288
#delta_theta[:, idx_update] = np.concatenate([
289289
# np.expand_dims(np.linalg.lstsq(a[i, :, :], b[i, :])[0], axis=-1)
@@ -516,7 +516,7 @@ def finalize(self):
516516
"""
517517
# Read from numpy-IRLS estimator specific model:
518518
self._hessian = - self.model.fim.compute()
519-
with np.errstate(linalg="ignore"):
519+
with np.errstate(all="ignore"):
520520
self._fisher_inv = np.linalg.inv(- self._hessian)
521521
self._jacobian = np.sum(np.abs(self.model.jac.compute() / self.model.x.shape[0]), axis=1)
522522
self._log_likelihood = self.model.ll_byfeature.compute()

0 commit comments

Comments
 (0)