
Commit 63e04a5

changed handling of FIM in numpy
1 parent ed61688 commit 63e04a5

2 files changed (+80, -8 lines)

batchglm/train/numpy/base_glm/model.py

Lines changed: 31 additions & 6 deletions

@@ -71,15 +71,15 @@ def ll_byfeature_j(self, j) -> np.ndarray:
         return np.sum(self.ll_j(j=j), axis=0)

     @abc.abstractmethod
-    def fim_weight(self) -> np.ndarray:
+    def fim_weight_aa(self) -> np.ndarray:
         pass

     @abc.abstractmethod
     def ybar(self) -> np.ndarray:
         pass

     @abc.abstractmethod
-    def fim_weight_j(self, j) -> np.ndarray:
+    def fim_weight_aa_j(self, j) -> np.ndarray:
         pass

     @abc.abstractmethod
@@ -95,12 +95,13 @@ def ybar_j(self, j) -> np.ndarray:
         pass

     @property
-    def fim(self) -> np.ndarray:
+    def fim_aa(self) -> np.ndarray:
         """
+        Location-location coefficient block of FIM

         :return: (features x inferred param x inferred param)
         """
-        w = self.fim_weight  # (observations x features)
+        w = self.fim_weight_aa  # (observations x features)
         # constraints: (observed param x inferred param)
         # design: (observations x observed param)
         # w: (observations x features)
@@ -112,6 +113,30 @@ def fim(self) -> np.ndarray:
             xh
         )

+    @abc.abstractmethod
+    def fim_ab(self) -> np.ndarray:
+        pass
+
+    @property
+    def fim_bb(self) -> np.ndarray:
+        pass
+
+    @property
+    def fim(self) -> np.ndarray:
+        """
+        Full FIM
+
+        :return: (features x inferred param x inferred param)
+        """
+        fim_aa = self.fim_aa
+        fim_bb = self.fim_bb
+        fim_ab = self.fim_ab
+        fim_ba = np.transpose(fim_ab, axes=[0, 2, 1])
+        return - np.concatenate([
+            np.concatenate([fim_aa, fim_ab], axis=2),
+            np.concatenate([fim_ba, fim_bb], axis=2)
+        ], axis=1)
+
     @abc.abstractmethod
     def hessian_weight_aa(self) -> np.ndarray:
         pass
@@ -190,7 +215,7 @@ def jac_a(self) -> np.ndarray:

         :return: (features x inferred param)
         """
-        w = self.fim_weight  # (observations x features)
+        w = self.fim_weight_aa  # (observations x features)
         ybar = self.ybar  # (observations x features)
         xh = np.matmul(self.design_loc, self.constraints_loc)  # (observations x inferred param)
         return np.einsum(
@@ -207,7 +232,7 @@ def jac_a_j(self, j) -> np.ndarray:
         # Make sure that dimensionality of sliced array is kept:
         if isinstance(j, int) or isinstance(j, np.int32) or isinstance(j, np.int64):
             j = [j]
-        w = self.fim_weight_j(j=j)  # (observations x features)
+        w = self.fim_weight_aa_j(j=j)  # (observations x features)
         ybar = self.ybar_j(j=j)  # (observations x features)
         xh = np.matmul(self.design_loc, self.constraints_loc)  # (observations x inferred param)
         return np.einsum(
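
Taken together, fim_aa keeps the per-feature weighted X^T W X contraction that was previously exposed as fim (the einsum subscripts themselves sit outside the hunks shown above), and the new fim property stitches the location-location, location-scale and scale-scale blocks into one negated block matrix per feature. Below is a minimal NumPy sketch of both steps, with made-up dimensions and an assumed subscript string, purely to illustrate the shapes involved; it is not batchglm code.

import numpy as np

# Toy dimensions (assumptions for illustration only).
n_obs, n_features, p_loc, p_scale = 50, 3, 2, 1

w = np.random.rand(n_obs, n_features)   # fim_weight_aa: (observations x features)
xh = np.random.rand(n_obs, p_loc)       # design_loc @ constraints_loc: (observations x inferred param)

# Per-feature X^T diag(w) X, the pattern described by the comments in fim_aa.
# The subscript string is an assumption; the actual einsum call is not visible in the hunk.
fim_aa = np.einsum("ob,of,oc->fbc", xh, w, xh)   # (features x inferred param x inferred param)

# Block assembly as in the new fim property, here with dense dummy off-diagonal and scale blocks.
fim_ab = np.random.rand(n_features, p_loc, p_scale)
fim_bb = np.random.rand(n_features, p_scale, p_scale)
fim_ba = np.transpose(fim_ab, axes=[0, 2, 1])
fim = -np.concatenate([
    np.concatenate([fim_aa, fim_ab], axis=2),
    np.concatenate([fim_ba, fim_bb], axis=2)
], axis=1)
assert fim.shape == (n_features, p_loc + p_scale, p_loc + p_scale)

Note that the negative binomial model below returns size-zero arrays for fim_ab and fim_bb rather than dense blocks like the dummies used here.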

batchglm/train/numpy/glm_nb/model.py

Lines changed: 49 additions & 2 deletions

@@ -36,7 +36,7 @@ def __init__(
         )

     @property
-    def fim_weight(self):
+    def fim_weight_aa(self):
         """

         :return: observations x features
@@ -51,7 +51,7 @@ def ybar(self) -> np.ndarray:
         """
         return np.asarray(self.x - self.location) / self.location

-    def fim_weight_j(self, j):
+    def fim_weight_aa_j(self, j):
         """

         :return: observations x features
@@ -113,6 +113,38 @@ def jac_weight_b_j(self, j):
         const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu)
         return scale * (const1 + const2 + const3)

+    @property
+    def fim_ab(self) -> np.ndarray:
+        """
+        Location-scale coefficient block of FIM
+
+        The negative binomial model is not fit as whole with IRLS but only the location model.
+        The location model is conditioned on the scale model estimates, which is why we only
+        supply the FIM of the location model and return an empty FIM for scale model components.
+        Note that there is also no closed form FIM for the scale-scale block. Returning a zero-array
+        here leads to singular matrices for the whole location-scale FIM in some cases that throw
+        linear algebra errors when inverted.
+
+        :return: (features x inferred param x inferred param)
+        """
+        return np.zeros([self.b_var.shape[1], 0, 0])
+
+    @property
+    def fim_bb(self) -> np.ndarray:
+        """
+        Scale-scale coefficient block of FIM
+
+        The negative binomial model is not fit as whole with IRLS but only the location model.
+        The location model is conditioned on the scale model estimates, which is why we only
+        supply the FIM of the location model and return an empty FIM for scale model components.
+        Note that there is also no closed form FIM for the scale-scale block. Returning a zero-array
+        here leads to singular matrices for the whole location-scale FIM in some cases that throw
+        linear algebra errors when inverted.
+
+        :return: (features x inferred param x inferred param)
+        """
+        return np.zeros([self.b_var.shape[1], 0, 0])
+
     @property
     def hessian_weight_ab(self):
         scale = self.scale
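
The docstrings for fim_ab and fim_bb warn that zero-filled scale blocks would make the assembled location-scale FIM singular and break inversion, which is why size-zero arrays are returned instead. A toy illustration of that failure mode with made-up numbers, not batchglm code:

import numpy as np

# One feature with 2 location coefficients and 1 scale coefficient (made-up sizes).
fim_aa = np.array([[2.0, 0.5],
                   [0.5, 1.0]])      # location-location block, positive definite
fim_ab = np.zeros((2, 1))            # zero-filled location-scale block
fim_bb = np.zeros((1, 1))            # zero-filled scale-scale block

full = np.block([[fim_aa, fim_ab],
                 [fim_ab.T, fim_bb]])
try:
    np.linalg.inv(full)
except np.linalg.LinAlgError as err:
    # The zero scale row and column make the matrix exactly singular.
    print("inversion fails:", err)

The final hunk of this file adds jac_b_handle, a closure built from the same digamma terms as jac_weight_b_j above:
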
@@ -209,3 +241,18 @@ def fun(x, eta_loc, b_var, xh_scale):
                 raise ValueError("type x %s not supported" % type(x))
             return self.np_clip_param(ll, "ll")
         return fun
+
+    def jac_b_handle(self):
+        def fun(x, eta_loc, b_var, xh_scale):
+            scale = np.exp(b_var)
+            loc = np.exp(eta_loc)
+            scale_plus_x = scale + x
+            r_plus_mu = scale + loc
+
+            # Define graphs for individual terms of constant term of hessian:
+            const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale)
+            const2 = - scale_plus_x / r_plus_mu
+            const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu)
+            return scale * (const1 + const2 + const3)
+
+        return fun
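
The terms assembled in jac_b_handle match the derivative of the negative binomial log-likelihood with respect to the log-transformed scale (dispersion) coefficient. The snippet below checks that closed form against a central finite difference of scipy.stats.nbinom.logpmf; the values and the parameterization (mean loc, dispersion scale, success probability scale / (scale + loc)) are assumptions made for the sake of the check, not batchglm code.

import numpy as np
import scipy.special
import scipy.stats

# Made-up single observation: count x, mean loc, dispersion scale.
x, loc, scale = 4.0, 3.0, 2.0

def nb_ll(log_scale):
    r = np.exp(log_scale)
    # Negative binomial parameterized by dispersion r and success probability r / (r + loc).
    return scipy.stats.nbinom.logpmf(x, r, r / (r + loc))

# Closed form, mirroring the const1/const2/const3 terms in jac_b_handle:
scale_plus_x = scale + x
r_plus_mu = scale + loc
const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale)
const2 = - scale_plus_x / r_plus_mu
const3 = np.log(scale) + 1.0 - np.log(r_plus_mu)
analytic = scale * (const1 + const2 + const3)

# Central finite difference in the log-scale coordinate:
eps = 1e-6
numeric = (nb_ll(np.log(scale) + eps) - nb_ll(np.log(scale) - eps)) / (2.0 * eps)
print(analytic, numeric)   # the two values should agree to several decimal places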
