Commit c9f8051

cleaned up batched data model in EstimatorGraph by factoring it out (BatchedDataModelGraphGLM);
also set default cholesky for FIM_b to True.
1 parent 42fbed3 commit c9f8051

8 files changed, +307 −213 lines changed

batchglm/models/base_glm/estimator.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 ESTIMATOR_PARAMS = MODEL_PARAMS.copy()
 ESTIMATOR_PARAMS.update({
     "loss": (),
-    "full_log_likelihood": ("features",),
+    "log_likelihood": ("features",),
     "gradient": ("features",),
     "hessians": ("features", "delta_var0", "delta_var1"),
     "fisher_inv": ("features", "delta_var0", "delta_var1"),
@@ -29,7 +29,7 @@ def __init__(self):

     @property
     def log_likelihood(self):
-        return self.params["full_log_likelihood"]
+        return self.params["log_likelihood"]

     @property
     def gradient(self):

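Note: the rename from "full_log_likelihood" to "log_likelihood" is purely a key change; the public log_likelihood property now reads the new key. A minimal usage sketch, assuming a fitted estimator estim (a hypothetical instance; only the key name and the ("features",) dimensions come from ESTIMATOR_PARAMS above):

    # "estim" is hypothetical; key name and dims come from ESTIMATOR_PARAMS.
    ll = estim.log_likelihood   # per-feature log-likelihood, dims ("features",)
    print(ll.shape)             # one value per feature (gene)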
batchglm/models/glm_nb/estimator.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def __init__(self, estim: AbstractEstimator):
         # causes evaluation of the properties that have not been computed during
         # training, such as the hessian.
         params = estim.to_xarray(
-            ["a_var", "b_var", "loss", "full_log_likelihood", "gradient", "hessians", "fisher_inv"],
+            ["a_var", "b_var", "loss", "log_likelihood", "gradient", "hessians", "fisher_inv"],
             coords=input_data.data
         )

batchglm/train/tf/base_glm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from .estimator_graph import GradientGraphGLM, NewtonGraphGLM, TrainerGraphGLM, EstimatorGraphGLM, FullDataModelGraphGLM
+from .estimator_graph import GradientGraphGLM, NewtonGraphGLM, TrainerGraphGLM, EstimatorGraphGLM, FullDataModelGraphGLM, BatchedDataModelGraphGLM
 from .hessians import HessiansGLM
 from .fim import FIMGLM
 from .jacobians import JacobiansGLM

batchglm/train/tf/base_glm/estimator_graph.py

Lines changed: 75 additions & 29 deletions
@@ -2,9 +2,9 @@
 import logging
 from typing import Union

-import tensorflow as tf
-
 import numpy as np
+import tensorflow as tf
+import xarray as xr

 try:
     import anndata
@@ -46,13 +46,36 @@ class FullDataModelGraphGLM:
     loss: tf.Tensor

     jac: tf.Tensor
-    neg_jac: tf.Tensor
-    neg_jac_train: tf.Tensor
-    neg_jac_train_a: tf.Tensor
-    neg_jac_train_b: tf.Tensor
-    hessian: tf.Tensor
-    neg_hessian: tf.Tensor
-    neg_hessian_train: tf.Tensor
+    jac_train: tf.Tensor
+
+    hessians: tf.Tensor
+    hessians_train: tf.Tensor
+
+    fim: tf.Tensor
+    fim_train: tf.Tensor
+
+    noise_model: str
+
+
+class BatchedDataModelGraphGLM:
+    """
+    Computational graph to evaluate model on batches of data set.
+
+    The model metrics of a batch which can be collected are:
+
+    - The model likelihood (cost function value).
+    - Model Jacobian matrix for trained parameters (for training).
+    - Model Hessian matrix for trained parameters (for training).
+    - Model Fisher information matrix for trained parameters (for training).
+    """
+    log_likelihood: tf.Tensor
+    norm_log_likelihood: tf.Tensor
+    norm_neg_log_likelihood: tf.Tensor
+    loss: tf.Tensor
+
+    jac_train: tf.Tensor
+    hessians_train: tf.Tensor
+    fim_train: tf.Tensor

     noise_model: str

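This attribute change is the heart of the factoring-out: instead of flat tensors such as neg_jac_train, the model graphs now carry container objects (jac_train, hessians_train, fim_train) whose fields the call sites below read as jac_train.neg_jac, hessians_train.neg_hessian, fim_train.fim_a, and so on. A minimal sketch of such a container; the class name and grouping are hypothetical, only the field names come from this diff:

    import tensorflow as tf

    class JacobianTrain:
        # Hypothetical container mirroring what jac_train exposes below:
        # the negative Jacobian for all trained parameters plus its
        # location-model (a) and scale-model (b) blocks.
        neg_jac: tf.Tensor
        neg_jac_a: tf.Tensor
        neg_jac_b: tf.Tensor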
@@ -124,7 +147,7 @@ def __init__(
         self.gradients_batch = gradients_batch

     def gradients_full_byfeature(self):
-        gradients_full_all = tf.transpose(self.full_data_model.neg_jac_train)
+        gradients_full_all = tf.transpose(self.full_data_model.jac_train.neg_jac)
         gradients_full = tf.concat([
             # tf.gradients(full_data_model.norm_neg_log_likelihood,
             #              model_vars.params_by_gene[i])[0]
@@ -137,7 +160,7 @@ def gradients_batched_byfeature(self):
         self.gradients_full_raw = gradients_full

     def gradients_batched_byfeature(self):
-        gradients_batch_all = tf.transpose(self.batch_jac.neg_jac)
+        gradients_batch_all = tf.transpose(self.batched_data_model.jac_train.neg_jac)
         gradients_batch = tf.concat([
             # tf.gradients(batch_model.norm_neg_log_likelihood,
             #              model_vars.params_by_gene[i])[0]
@@ -150,11 +173,11 @@ def gradients_batched_byfeature(self):
         self.gradients_batch_raw = gradients_batch

     def gradients_full_global(self):
-        gradients_full = tf.transpose(self.full_data_model.neg_jac_train)
+        gradients_full = tf.transpose(self.full_data_model.jac_train.neg_jac)
         self.gradients_full_raw = gradients_full

     def gradients_batched_global(self):
-        gradients_batch = tf.transpose(self.batch_jac.neg_jac)
+        gradients_batch = tf.transpose(self.batched_data_model.jac_train.neg_jac)
         self.gradients_batch_raw = gradients_batch

@@ -193,10 +216,10 @@ def __init__(
     ):
         if provide_optimizers["nr"]:
             nr_update_full_raw, nr_update_batched_raw = self.build_updates(
-                full_lhs=self.full_data_model.neg_hessian_train,
-                batched_lhs=self.batch_hessians.neg_hessian,
-                full_rhs=self.full_data_model.neg_jac_train,
-                batched_rhs=self.batch_jac.neg_jac,
+                full_lhs=self.full_data_model.hessians_train.neg_hessian,
+                batched_lhs=self.batched_data_model.hessians_train.neg_hessian,
+                full_rhs=self.full_data_model.jac_train.neg_jac,
+                batched_rhs=self.batched_data_model.jac_train.neg_jac,
                 termination_type=termination_type,
                 psd=False
             )
@@ -219,9 +242,9 @@
             # passed here with psd=True.
             irls_update_a_full, irls_update_a_batched = self.build_updates(
                 full_lhs=self.full_data_model.fim_train.fim_a,
-                batched_lhs=self.batch_fim.fim_a,
-                full_rhs=self.full_data_model.neg_jac_train_a,
-                batched_rhs=self.batch_jac.neg_jac_a,
+                batched_lhs=self.batched_data_model.fim_train.fim_a,
+                full_rhs=self.full_data_model.jac_train.neg_jac_a,
+                batched_rhs=self.batched_data_model.jac_train.neg_jac_a,
                 termination_type=termination_type,
                 psd=True
             )
@@ -232,11 +255,11 @@
             if train_r:
                 irls_update_b_full, irls_update_b_batched = self.build_updates(
                     full_lhs=self.full_data_model.fim_train.fim_b,
-                    batched_lhs=self.batch_fim.fim_b,
-                    full_rhs=self.full_data_model.neg_jac_train_b,
-                    batched_rhs=self.batch_jac.neg_jac_b,
+                    batched_lhs=self.batched_data_model.fim_train.fim_b,
+                    full_rhs=self.full_data_model.jac_train.neg_jac_b,
+                    batched_rhs=self.batched_data_model.jac_train.neg_jac_b,
                     termination_type=termination_type,
-                    psd=False
+                    psd=True  # TODO proove
                 )
             else:
                 irls_update_b_full = None
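Switching psd to True for the fim_b update matches the commit message ("set default cholesky for FIM_b to True"): a symmetric positive (semi-)definite left-hand side allows the linear system to be solved with a Cholesky factorization rather than a general solver. A sketch of the kind of dispatch this flag presumably enables inside build_updates (the dispatch itself is not part of this diff; the function name and signature here are illustrative):

    import tensorflow as tf

    def solve_update(lhs, rhs, psd):
        # For a positive-definite lhs (e.g. a Fisher information matrix),
        # a Cholesky solve is cheaper and numerically more stable than
        # a general LU-based solve.
        if psd:
            chol = tf.linalg.cholesky(lhs)
            return tf.linalg.cholesky_solve(chol, rhs)
        return tf.linalg.solve(lhs, rhs)

The retained "TODO proove" marks that positive semi-definiteness of fim_b has not been shown yet, so this default is provisional.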
@@ -608,8 +631,11 @@ def __init__(
             num_design_scale_params,
             num_loc_params,
             num_scale_params,
-            graph: tf.Graph = None,
-            batch_size: int = None,
+            graph: tf.Graph,
+            batch_size: int,
+            constraints_loc: xr.DataArray,
+            constraints_scale: xr.DataArray,
+            dtype
     ):
         """
@@ -622,6 +648,16 @@
         :param num_design_scale_params: int
             Number of parameters per feature in scale model.
         :param graph: tf.Graph
+        :param constraints_loc: tensor (all parameters x dependent parameters)
+            Tensor that encodes how complete parameter set which includes dependent
+            parameters arises from indepedent parameters: all = <constraints, indep>.
+            This tensor describes this relation for the mean model.
+            This form of constraints is used in vector generalized linear models (VGLMs).
+        :param constraints_scale: tensor (all parameters x dependent parameters)
+            Tensor that encodes how complete parameter set which includes dependent
+            parameters arises from indepedent parameters: all = <constraints, indep>.
+            This tensor describes this relation for the dispersion model.
+            This form of constraints is used in vector generalized linear models (VGLMs).
         """
         TFEstimatorGraph.__init__(
             self=self,
@@ -636,19 +672,29 @@
         self.num_scale_params = num_scale_params
         self.batch_size = batch_size

+        self.constraints_loc = self._set_constraints(
+            constraints=constraints_loc,
+            dtype=dtype
+        )
+        self.constraints_scale = self._set_constraints(
+            constraints=constraints_scale,
+            dtype=dtype
+        )
+
+        self.learning_rate = tf.placeholder(dtype, shape=(), name="learning_rate")
+
     def _set_constraints(
             self,
             constraints,
-            design,
             dtype
     ):
         if constraints is None:
             return tf.eye(
-                num_rows=tf.constant(design.shape[1], shape=(), dtype="int32"),
+                num_rows=tf.constant(self.num_design_loc_params, shape=(), dtype="int32"),
                 dtype=dtype
             )
         else:
-            assert constraints.shape[0] == design.shape[1], "constraint dimension mismatch"
+            assert constraints.shape[0] == self.num_design_loc_params, "constraint dimension mismatch"
             return tf.cast(constraints, dtype=dtype)

     @abc.abstractmethod

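The new constraints_loc / constraints_scale arguments encode the VGLM relation from the docstring, all = <constraints, indep>: the complete parameter set is a linear image of the independent parameters, and _set_constraints falls back to an identity matrix when no constraints are supplied. A minimal numpy sketch of that relation (matrix contents invented for illustration):

    import numpy as np

    n_all, n_indep = 4, 3
    # Hypothetical constraint matrix: the fourth parameter is constrained
    # to be the negative sum of the three independent ones.
    constraints = np.array([
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [-1.0, -1.0, -1.0],
    ])
    assert constraints.shape == (n_all, n_indep)

    indep = np.array([0.5, -1.2, 2.0])   # independent (free) parameters
    all_params = constraints @ indep     # all = <constraints, indep>

    # With constraints=None, _set_constraints substitutes an identity matrix,
    # in which case all_params == indep.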
batchglm/train/tf/base_glm_all/estimator.py

Lines changed: 4 additions & 12 deletions
@@ -314,10 +314,10 @@ def train(self, *args,

         if train_mu or train_r:
             if use_batching:
-                loss = self.model.loss
+                loss = self.model.batched_data_model.loss
                 train_op = self.model.trainer_batch.train_op_by_name(optim_algo)
             else:
-                loss = self.model.full_loss
+                loss = self.model.full_data_model.loss
                 train_op = self.model.trainer_full.train_op_by_name(optim_algo)

         super().train(*args,
@@ -362,23 +362,15 @@ def b_var(self):

     @property
     def loss(self):
-        return self.to_xarray("full_loss")
-
-    @property
-    def batch_loss(self):
         return self.to_xarray("loss")

     @property
     def log_likelihood(self):
-        return self.to_xarray("full_log_likelihood", coords=self.input_data.data.coords)
-
-    @property
-    def batch_gradient(self):
-        return self.to_xarray("gradient", coords=self.input_data.data.coords)
+        return self.to_xarray("log_likelihood", coords=self.input_data.data.coords)

     @property
     def gradient(self):
-        return self.to_xarray("full_gradient", coords=self.input_data.data.coords)
+        return self.to_xarray("gradient", coords=self.input_data.data.coords)

     @property
     def hessians(self):

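With batch_loss and batch_gradient gone, loss, log_likelihood, and gradient are the single public accessors, and train() picks the loss tensor from the matching sub-graph. A usage sketch; estim is a hypothetical estimator instance, and only the argument names use_batching and optim_algo are taken from this diff:

    # "estim" is hypothetical; other train() arguments are elided.
    estim.train(optim_algo="nr", use_batching=True)     # optimizes batched_data_model.loss
    estim.train(optim_algo="irls", use_batching=False)  # optimizes full_data_model.loss

    print(estim.loss)            # formerly full_loss; batch_loss was removed
    print(estim.log_likelihood)  # xarray with coords from the input data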