Skip to content

Commit e746783

Browse files
switched final fim inverse to tf linalg inv from pinv
1 parent 71bbe7e commit e746783

File tree

5 files changed

+7
-29
lines changed

5 files changed

+7
-29
lines changed

batchglm/pkg_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
TRUST_REGION_ETA1 = 0.25
3535
TRUST_REGION_ETA2 = 0.25
3636
TRUST_REGION_T1 = 0.5 # Fast collapse to avoid trailing.
37-
TRUST_REGION_T2 = 2. # Allow expansion if not shrinking.
37+
TRUST_REGION_T2 = 1.5 # Allow expansion if not shrinking.
3838
TRUST_REGION_UPPER_BOUND = 1e5
3939

4040
# Convergence hyper-parameters:

batchglm/train/tf/base_glm_all/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def finalize(self):
329329
a_var = self.session.run(self.model.a_var)
330330
b_var = self.session.run(self.model.b_var)
331331
fisher_inv = self.session.run(self.model.fisher_inv)
332-
hessians = self.session.run(self.model.hessians)
332+
hessian = self.session.run(self.model.hessian)
333333
jacobian = self.session.run(self.model.gradients)
334334
log_likelihood = self.session.run(self.model.log_likelihood)
335335
loss = self.session.run(self.model.loss)
@@ -339,7 +339,7 @@ def finalize(self):
339339
self.model._a_var = a_var
340340
self.model._b_var = b_var
341341
self._fisher_inv = fisher_inv
342-
self._hessians = hessians
342+
self._hessian = hessian
343343
self._jacobian = jacobian
344344
self._log_likelihood = log_likelihood
345345
self._loss = loss

batchglm/train/tf/base_glm_all/estimator_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ def __init__(
530530
)
531531
self.loss = self.full_data_model.loss_final
532532
self.log_likelihood = self.full_data_model.log_likelihood_final
533-
self.hessians = self.full_data_model.hessians_final
534-
self.fisher_inv = op_utils.pinv(-self.full_data_model.hessians_final) # TODO switch for fim?
533+
self.hessian = self.full_data_model.hessians_final
534+
self.fisher_inv = tf.linalg.inv(-self.full_data_model.hessians_final) # TODO switch for fim?
535535
# Summary statistics on feature-wise model gradients:
536536
self.gradients = tf.reduce_sum(tf.abs(self.full_data_model.neg_jac_final / num_observations), axis=1)
537537

batchglm/train/tf/glm_norm/fim.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ def _weight_fim_aa(
1414
loc,
1515
scale
1616
):
17-
scalar_one = tf.constant(1, shape=(), dtype=self.dtype)
18-
W = tf.square(tf.divide(scalar_one, scale))
17+
W = tf.square(tf.divide(tf.ones_like(scale), scale))
1918

2019
return W
2120

@@ -24,7 +23,6 @@ def _weight_fim_bb(
2423
loc,
2524
scale
2625
):
27-
scalar_two = tf.constant(2, shape=(), dtype=self.dtype)
28-
W = scalar_two * tf.ones_like(loc)
26+
W = tf.constant(2, shape=loc.shape, dtype=self.dtype)
2927

3028
return W

batchglm/train/tf/ops.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,26 +33,6 @@ def swap_dims(tensor, axis0, axis1, exec_transpose=True, return_perm=False, name
3333
return perm1
3434

3535

36-
def pinv(matrix, threshold=1e-5):
37-
"""
38-
Calculate the Moore-Penrose pseudo-inverse of the last two dimensions of some matrix.
39-
40-
E.g. if `matrix` has some shape [..., K, L, M, N], this method will inverse each [M, N] matrix.
41-
42-
:param matrix: The matrix to invert
43-
:param threshold: threshold value
44-
:return: the pseudo-inverse of `matrix`
45-
"""
46-
47-
s, u, v = tf.linalg.svd(matrix) # , full_matrices=True, compute_uv=True)
48-
49-
adj_threshold = tf.reduce_max(s, axis=-1, keepdims=True) * threshold
50-
s_inv = tf.where(s > tf.broadcast_to(adj_threshold, s.shape), tf.math.reciprocal(s), tf.zeros_like(s))
51-
s_inv = tf.linalg.diag(s_inv)
52-
53-
return v @ (s_inv @ swap_dims(u, axis0=-1, axis1=-2))
54-
55-
5636
def stacked_lstsq(L, b, rcond=1e-10, name="stacked_lstsq"):
5737
r"""
5838
Solve `Lx = b`, via SVD least squares cutting of small singular values

0 commit comments

Comments
 (0)