
Commit 2a6c262

fixed interface to diffxpy
1 parent 90e3bdd · commit 2a6c262

8 files changed (+63 lines, -64 lines)


batchglm/models/base/estimator.py

Lines changed: 12 additions & 0 deletions
@@ -104,3 +104,15 @@ def validate_data(self, **kwargs):
     @property
     def input_data(self):
         return self._input_data
+
+    @property
+    def X(self):
+        return self.input_data.X
+
+    @property
+    def features(self):
+        return self.input_data.features
+
+    @property
+    def loss(self):
+        return self.params["loss"]
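These accessors are what downstream packages such as diffxpy read back from a fitted estimator. A minimal usage sketch, assuming `estim` is an already-fitted estimator store that exposes exactly the properties added above (the helper function itself is hypothetical):

    def summarize_fit(estim):
        # Observation matrix and feature labels, forwarded from estim.input_data.
        x = estim.X
        features = estim.features
        # Training loss stored under estim.params["loss"].
        loss = estim.loss
        print("data shape:", x.shape, "n_features:", len(features), "loss:", loss)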

batchglm/models/base_glm/estimator.py

Lines changed: 3 additions & 6 deletions
@@ -11,6 +11,7 @@
 ESTIMATOR_PARAMS = MODEL_PARAMS.copy()
 ESTIMATOR_PARAMS.update({
     "loss": (),
+    "full_log_likelihood": ("features",),
     "gradient": ("features",),
     "hessians": ("features", "delta_var0", "delta_var1"),
     "fisher_inv": ("features", "delta_var0", "delta_var1"),
@@ -27,12 +28,8 @@ def __init__(self):
         super(_EstimatorStore_XArray_Base, self).__init__()
 
     @property
-    def input_data(self):
-        return self._input_data
-
-    @property
-    def loss(self):
-        return self.params["loss"]
+    def log_likelihood(self):
+        return self.params["full_log_likelihood"]
 
     @property
     def gradient(self):

batchglm/models/glm_nb/estimator.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def __init__(self, estim: AbstractEstimator):
         # causes evaluation of the properties that have not been computed during
         # training, such as the hessian.
         params = estim.to_xarray(
-            ["a_var", "b_var", "loss", "gradient", "hessians", "fisher_inv"],
+            ["a_var", "b_var", "loss", "full_log_likelihood", "gradient", "hessians", "fisher_inv"],
            coords=input_data.data
         )

batchglm/train/tf/base/estimator.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
 import abc
-from enum import Enum
 from typing import Dict, Any, Union, List, Iterable
 
 import os
@@ -247,7 +246,7 @@ def train(self, *args,
             )
 
             tf.logging.info(
-                "Step: %d\tloss: %f\t models converged %i",
+                "Step: \t%d\t loss: %f\t models converged %i",
                 train_step,
                 global_loss,
                 np.sum(self.model.model_vars.converged).astype("int32")

batchglm/train/tf/base_glm/estimator_graph.py

Lines changed: 2 additions & 2 deletions
@@ -218,7 +218,7 @@ def __init__(
                 # with the Cholesky decomposition. This information is
                 # passed here with psd=True.
                 irls_update_a_full, irls_update_a_batched = self.build_updates(
-                    full_lhs=self.full_data_model.fim.fim_a,
+                    full_lhs=self.full_data_model.fim_train.fim_a,
                     batched_lhs=self.batch_fim.fim_a,
                     full_rhs=self.full_data_model.neg_jac_train_a,
                     batched_rhs=self.batch_jac.neg_jac_a,
@@ -231,7 +231,7 @@ def __init__(
 
                 if train_r:
                     irls_update_b_full, irls_update_b_batched = self.build_updates(
-                        full_lhs=self.full_data_model.fim.fim_b,
+                        full_lhs=self.full_data_model.fim_train.fim_b,
                        batched_lhs=self.batch_fim.fim_b,
                        full_rhs=self.full_data_model.neg_jac_train_b,
                        batched_rhs=self.batch_jac.neg_jac_b,
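For background, the update that `build_updates` assembles amounts to solving a linear system whose left-hand side is the Fisher information matrix of the trained submodel; because the FIM is symmetric positive (semi-)definite, a Cholesky solve applies, which is what the psd=True remark above refers to. A generic NumPy sketch of that single solve, not batchglm's actual implementation (all names here are illustrative):

    import numpy as np

    def irls_delta(fim, rhs):
        # Solve fim @ delta = rhs via a Cholesky factorization, exploiting that
        # the Fisher information matrix is symmetric positive (semi-)definite.
        chol = np.linalg.cholesky(fim)        # fim = chol @ chol.T
        y = np.linalg.solve(chol, rhs)        # forward substitution
        delta = np.linalg.solve(chol.T, y)    # backward substitution
        return delta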

batchglm/train/tf/base_glm_all/estimator.py

Lines changed: 12 additions & 8 deletions
@@ -293,9 +293,9 @@ def train(self, *args,
        if newton_type_mode:
            if learning_rate != 1:
                logger.warning(
-                    "Newton-rhapson or IRLS in base_glm_all is used with learing rate " +
+                    "Newton-rhapson or IRLS in base_glm_all is used with learning rate " +
                    str(learning_rate) +
-                    ". Newton-rhapson and IRLS should only be used with learing rate = 1."
+                    ". Newton-rhapson and IRLS should only be used with learning rate = 1."
                )
 
        # Report all parameters after all defaults were imputed in settings:
@@ -360,17 +360,21 @@ def a_var(self):
    def b_var(self):
        return self.to_xarray("b_var", coords=self.input_data.data.coords)
 
+    @property
+    def loss(self):
+        return self.to_xarray("full_loss")
+
    @property
    def batch_loss(self):
        return self.to_xarray("loss")
 
    @property
-    def batch_gradient(self):
-        return self.to_xarray("gradient", coords=self.input_data.data.coords)
+    def log_likelihood(self):
+        return self.to_xarray("full_log_likelihood", coords=self.input_data.data.coords)
 
    @property
-    def loss(self):
-        return self.to_xarray("full_loss")
+    def batch_gradient(self):
+        return self.to_xarray("gradient", coords=self.input_data.data.coords)
 
    @property
    def gradient(self):
@@ -388,9 +392,9 @@ def finalize(self):
        if self.noise_model == "nb":
            from .external_nb import EstimatorStoreXArray
        else:
-            raise ValueError("noise model not rewcognized")
+            raise ValueError("noise model not recognized")
 
-        logger.debug("Collect and compute ouptut")
+        logger.debug("Collect and compute output")
        store = EstimatorStoreXArray(self)
        logger.debug("Closing session")
        self.close_session()
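After training, the estimator now distinguishes full-data from batch quantities. A hedged sketch of how these properties might be queried on a trained estimator `estim` of this class; only the property names come from this commit, the surrounding calls are illustrative:

    full_loss = estim.loss              # to_xarray("full_loss"): loss over the full data set
    last_batch_loss = estim.batch_loss  # to_xarray("loss"): loss of the last training batch
    ll = estim.log_likelihood           # per-feature full-data log-likelihood, dims ("features",)
    print(float(full_loss), float(last_batch_loss), ll.shape)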

batchglm/train/tf/base_glm_all/estimator_graph.py

Lines changed: 26 additions & 20 deletions
@@ -130,7 +130,7 @@ def map_model(idx, data) -> BasicModelGraph:
            else:
                hessians_train = hessians_full
 
-            fim_train = FIM(
+            fim_full = FIM(
                batched_data=batched_data,
                sample_indices=sample_indices,
                constraints_loc=constraints_loc,
@@ -139,10 +139,27 @@ def map_model(idx, data) -> BasicModelGraph:
                mode=pkg_constants.HESSIAN_MODE,
                noise_model=noise_model,
                iterator=True,
-                update_a=train_a,
-                update_b=train_b,
+                update_a=True,
+                update_b=True,
                dtype=dtype
            )
+            # Fisher information matrix of submodel which is to be trained.
+            if not train_a or not train_b:
+                fim_train = FIM(
+                    batched_data=batched_data,
+                    sample_indices=sample_indices,
+                    constraints_loc=constraints_loc,
+                    constraints_scale=constraints_scale,
+                    model_vars=model_vars,
+                    mode=pkg_constants.HESSIAN_MODE,
+                    noise_model=noise_model,
+                    iterator=True,
+                    update_a=train_a,
+                    update_b=train_b,
+                    dtype=dtype
+                )
+            else:
+                fim_train = fim_full
 
        with tf.name_scope("jacobians"):
            # Jacobian of full model for reporting.
@@ -210,7 +227,8 @@ def map_model(idx, data) -> BasicModelGraph:
        self.neg_hessian = hessians_full.neg_hessian
        self.neg_hessian_train = hessians_train.neg_hessian
 
-        self.fim = fim_train
+        self.fim_full = fim_full
+        self.fim_train = fim_train
 
 
 class EstimatorGraphAll(EstimatorGraphGLM):
@@ -439,20 +457,7 @@ def __init__(
                noise_model=noise_model,
                dtype=dtype
            )
-            full_data_loss = full_data_model.loss
-            fisher_inv = op_utils.pinv(full_data_model.neg_hessian)
-
-            # with tf.name_scope("hessian_diagonal"):
-            #     hessian_diagonal = [
-            #         tf.map_fn(
-            #             # elems=tf.transpose(hess, perm=[2, 0, 1]),
-            #             elems=hess,
-            #             fn=tf.diag_part,
-            #             parallel_iterations=pkg_constants.TF_LOOP_PARALLEL_ITERATIONS
-            #         )
-            #         for hess in full_data_model.hessians
-            #     ]
-            #     fisher_a, fisher_b = hessian_diagonal
+            full_data_fisher_inv = op_utils.pinv(full_data_model.neg_hessian)  # TODO switch for fim
 
            mu = full_data_model.mu
            r = full_data_model.r
@@ -473,16 +478,17 @@ def __init__(
        self.r = r
        self.sigma2 = sigma2
 
+        self.full_loss = full_data_model.loss
+        self.full_log_likelihood = full_data_model.log_likelihood
        self.batch_probs = batch_model.probs
        self.batch_log_probs = batch_model.log_probs
        self.batch_log_likelihood = batch_model.norm_log_likelihood
 
        self.sample_selection = sample_selection
        self.full_data_model = full_data_model
 
-        self.full_loss = full_data_loss
        self.hessians = full_data_model.hessian
-        self.fisher_inv = fisher_inv
+        self.fisher_inv = full_data_fisher_inv
 
        self.idx_nonconverged = idx_nonconverged
batchglm/train/tf/glm_nb/estimator.py

Lines changed: 6 additions & 25 deletions
@@ -25,46 +25,27 @@ class TrainingStrategy(Enum):
        DEFAULT = [
            {
                "convergence_criteria": "all_converged_ll",
-                "stopping_criteria": 1e-8,
+                "stopping_criteria": 1e-6,
                "use_batching": False,
-                "optim_algo": "Newton",
+                "optim_algo": "irls",
            },
        ]
        QUICK = [
            {
                "convergence_criteria": "all_converged_ll",
-                "stopping_criteria": 1e-6,
-                "use_batching": False,
-                "optim_algo": "Newton",
-            },
-        ]
-        PRE_INITIALIZED = [
-            {
-                "convergence_criteria": "scaled_moving_average",
-                "stopping_criteria": 1e-10,
-                "loss_window_size": 10,
+                "stopping_criteria": 1e-4,
                "use_batching": False,
-                "optim_algo": "newton",
+                "optim_algo": "irls",
            },
        ]
-        CONSTRAINED = [  # Should not contain newton-rhapson right now.
+        EXACT = [
            {
-                "learning_rate": 0.5,
                "convergence_criteria": "all_converged_ll",
                "stopping_criteria": 1e-8,
-                "loss_window_size": 10,
                "use_batching": False,
-                "optim_algo": "ADAM",
+                "optim_algo": "irls",
            },
        ]
-        CONTINUOUS = [
-            {
-                "convergence_criteria": "all_converged_ll",
-                "stopping_criteria": 1e-8,
-                "use_batching": False,
-                "optim_algo": "Newton",
-            }
-        ]
 
    def __init__(
            self,
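The surviving strategies (DEFAULT, QUICK, EXACT) all run IRLS on the full data with a log-likelihood convergence criterion and differ mainly in the stopping threshold. A minimal sketch of driving training with such a step list on a constructed estimator `estim`; that `estim.train()` accepts exactly these keys as keyword arguments is an assumption, not something shown in this diff:

    # Hypothetical: forward each step's settings of a strategy list to train().
    custom_strategy = [
        {
            "convergence_criteria": "all_converged_ll",
            "stopping_criteria": 1e-6,
            "use_batching": False,
            "optim_algo": "irls",
        },
    ]
    for step in custom_strategy:
        estim.train(**step)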
