Commit 42fbed3

fixed some small bugs and cleaned up
1 parent a573fcd commit 42fbed3

File tree

4 files changed: +20 -15 lines

    batchglm/train/tf/glm_nb/estimator.py
    batchglm/train/tf/glm_nb/fim.py
    batchglm/train/tf/glm_nb/hessians.py
    batchglm/unit_test/base_glm/test_extreme_values_glm.py

batchglm/train/tf/glm_nb/estimator.py

Lines changed: 5 additions & 5 deletions
@@ -185,29 +185,29 @@ def init_par(
         logger.info("Should train r: %s", self._train_scale)
 
         if init_model is not None:
+            # Locations model:
             if isinstance(init_a, str) and (init_a.lower() == "auto" or init_a.lower() == "init_model"):
-                # location
                 my_loc_names = set(self.input_data.design_loc_names.values)
                 my_loc_names = my_loc_names.intersection(init_model.input_data.design_loc_names.values)
 
                 init_loc = np.zeros([self.input_data.num_loc_params, self.input_data.num_features])
                 for parm in my_loc_names:
                     init_idx = np.where(init_model.input_data.design_loc_names == parm)
-                    my_idx = np.where(input_data.design_loc_names == parm)
+                    my_idx = np.where(self.input_data.design_loc_names == parm)
                     init_loc[my_idx] = init_model.par_link_loc[init_idx]
 
                 init_a = init_loc
                 logger.info("Using initialization based on input model for mean")
 
+            # Scale model:
             if isinstance(init_b, str) and (init_b.lower() == "auto" or init_b.lower() == "init_model"):
-                # scale
-                my_scale_names = set(input_data.design_scale_names.values)
+                my_scale_names = set(self.input_data.design_scale_names.values)
                 my_scale_names = my_scale_names.intersection(init_model.input_data.design_scale_names.values)
 
                 init_scale = np.zeros([self.input_data.num_scale_params, self.input_data.num_features])
                 for parm in my_scale_names:
                     init_idx = np.where(init_model.input_data.design_scale_names == parm)
-                    my_idx = np.where(input_data.design_scale_names == parm)
+                    my_idx = np.where(self.input_data.design_scale_names == parm)
                     init_scale[my_idx] = init_model.par_link_scale[init_idx]
 
                 init_b = init_scale
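
The substantive fix in this file replaces the stale local name input_data with self.input_data, so that initialization from a previously fitted model copies coefficients into the new model's parameter matrix by matching design-matrix column names. A minimal standalone sketch of that name-matching copy, using made-up toy arrays rather than batchglm's InputData and par_link_loc objects, is:

    import numpy as np

    # Toy stand-ins for the design column names of a previously fitted model
    # (init_names) and of the model being initialized (my_names); both the
    # names and the values below are hypothetical.
    init_names = np.array(["Intercept", "conditionA", "batch1"])
    my_names = np.array(["Intercept", "conditionA"])

    # Coefficients of the previous model: [num_loc_params, num_features].
    init_par = np.array([[1.0, 2.0],
                         [0.5, -0.3],
                         [0.1, 0.2]])

    # Zero-initialized parameter matrix of the new model.
    init_loc = np.zeros([len(my_names), init_par.shape[1]])

    # Copy every coefficient whose design column exists in both models,
    # matching by name as in init_par() above.
    for parm in set(my_names).intersection(init_names):
        init_idx = np.where(init_names == parm)
        my_idx = np.where(my_names == parm)
        init_loc[my_idx] = init_par[init_idx]

    print(init_loc)  # rows ordered by the new model's design columns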

batchglm/train/tf/glm_nb/fim.py

Lines changed: 2 additions & 2 deletions
@@ -32,11 +32,11 @@ def _W_bb(
         digamma_r = tf.math.digamma(x=r)
         digamma_r_plus_mu = tf.math.digamma(x=r_plus_mu)
 
-        const1 = tf.multiply(scalar_two, tf.add(  # [observations, features]
+        const1 = tf.multiply(scalar_two, tf.add(
             digamma_r,
             digamma_r_plus_mu
         ))
-        const2 = tf.multiply(r, tf.add(  # [observations, features]
+        const2 = tf.multiply(r, tf.add(
             tf.math.polygamma(a=scalar_one, x=r),
             tf.math.polygamma(a=scalar_one, x=r_plus_mu)
         ))
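
The change here only drops stale shape comments; the dispersion-model Fisher-information weight still combines digamma and first-order polygamma (trigamma) terms in r and r + mu. Purely as a reference for the two expressions touched by this hunk, a sketch with toy arrays and SciPy in place of TensorFlow (not batchglm's FIM code):

    import numpy as np
    from scipy.special import digamma, polygamma

    # Hypothetical NB mean (mu) and dispersion (r) values; in the graph these
    # are tensors of shape [observations, features].
    mu = np.array([[1.5, 3.0], [0.7, 2.2]])
    r = np.array([[2.0, 2.0], [2.0, 2.0]])
    r_plus_mu = r + mu

    # The two constants assembled in _W_bb above, written with SciPy:
    const1 = 2.0 * (digamma(r) + digamma(r_plus_mu))          # scalar_two * (digamma(r) + digamma(r + mu))
    const2 = r * (polygamma(1, r) + polygamma(1, r_plus_mu))  # r * (trigamma(r) + trigamma(r + mu))

    print(const1)
    print(const2)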

batchglm/train/tf/glm_nb/hessians.py

Lines changed: 7 additions & 7 deletions
@@ -16,9 +16,9 @@ def _W_ab(
             r,
     ):
         const = tf.multiply(
-            mu * r,  # [observations, features]
+            mu * r,
             tf.divide(
-                X - mu,  # [observations, features]
+                X - mu,
                 tf.square(mu + r)
             )
         )
@@ -31,7 +31,7 @@ def _W_aa(
             r,
     ):
         const = tf.negative(tf.multiply(
-            mu,  # [observations, features]
+            mu,
             tf.divide(
                 (X / r) + 1,
                 tf.square((mu / r) + 1)
@@ -51,11 +51,11 @@ def _W_bb(
         r_plus_mu = r + mu
         r_plus_x = r + X
         # Define graphs for individual terms of constant term of hessian:
-        const1 = tf.add(  # [observations, features]
+        const1 = tf.add(
             tf.math.digamma(x=r_plus_x),
             r * tf.math.polygamma(a=scalar_one, x=r_plus_x)
         )
-        const2 = tf.negative(tf.add(  # [observations, features]
+        const2 = tf.negative(tf.add(
             tf.math.digamma(x=r),
             r * tf.math.polygamma(a=scalar_one, x=r)
         ))
@@ -66,11 +66,11 @@ def _W_bb(
             ),
             tf.square(r_plus_mu)
         ))
-        const4 = tf.add(  # [observations, features]
+        const4 = tf.add(
             tf.log(r),
             scalar_two - tf.log(r_plus_mu)
         )
-        const = tf.add_n([const1, const2, const3, const4])  # [observations, features]
+        const = tf.add_n([const1, const2, const3, const4])
         const = tf.multiply(r, const)
         return const
 
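
As in fim.py, these edits only remove shape comments. For reference, the two weight blocks that are fully visible in the hunks, _W_ab (mixed location/scale) and _W_aa (location/location), reduce to simple elementwise expressions; a toy NumPy sketch of just those expressions, with hypothetical values rather than batchglm's API:

    import numpy as np

    # Hypothetical count data X with NB mean mu and dispersion r;
    # real shapes are [observations, features].
    X = np.array([[3.0, 0.0], [5.0, 1.0]])
    mu = np.array([[2.5, 0.4], [4.0, 1.5]])
    r = np.array([[2.0, 2.0], [2.0, 2.0]])

    # _W_ab above: mu * r * (X - mu) / (mu + r)**2
    W_ab = (mu * r) * (X - mu) / np.square(mu + r)

    # _W_aa above: -mu * ((X / r) + 1) / ((mu / r) + 1)**2
    W_aa = -(mu * ((X / r) + 1.0) / np.square((mu / r) + 1.0))

    print(W_ab)
    print(W_aa)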

batchglm/unit_test/base_glm/test_extreme_values_glm.py

Lines changed: 6 additions & 1 deletion
@@ -1,10 +1,15 @@
 import abc
+import logging
 from typing import List
 import unittest
 import numpy as np
 
+import batchglm.api as glm
 from batchglm.models.base_glm import _Estimator_GLM, InputData, _Simulator_GLM
 
+glm.setup_logging(verbosity="WARNING", stream="STDOUT")
+logger = logging.getLogger(__name__)
+
 
 class _Test_ExtremValues_GLM_Estim():
 
@@ -23,7 +28,7 @@ def estimate(
             "convergence_criteria": "all_converged_ll",
             "stopping_criteria": 1e-4,
             "use_batching": False,
-            "optim_algo": "Newton",
+            "optim_algo": "IRLS",
         },
     ])
 