@@ -2,9 +2,9 @@
 import logging
 from typing import Union
 
-import tensorflow as tf
-
 import numpy as np
+import tensorflow as tf
+import xarray as xr
 
 try:
     import anndata
@@ -46,13 +46,36 @@ class FullDataModelGraphGLM:
     loss: tf.Tensor
 
     jac: tf.Tensor
-    neg_jac: tf.Tensor
-    neg_jac_train: tf.Tensor
-    neg_jac_train_a: tf.Tensor
-    neg_jac_train_b: tf.Tensor
-    hessian: tf.Tensor
-    neg_hessian: tf.Tensor
-    neg_hessian_train: tf.Tensor
+    jac_train: tf.Tensor
+
+    hessians: tf.Tensor
+    hessians_train: tf.Tensor
+
+    fim: tf.Tensor
+    fim_train: tf.Tensor
+
+    noise_model: str
+
+
+class BatchedDataModelGraphGLM:
+    """
+    Computational graph to evaluate the model on batches of the data set.
+
+    The model metrics that can be collected per batch are:
+
+    - The model likelihood (cost function value).
+    - Model Jacobian matrix for trained parameters (for training).
+    - Model Hessian matrix for trained parameters (for training).
+    - Model Fisher information matrix for trained parameters (for training).
+    """
+    log_likelihood: tf.Tensor
+    norm_log_likelihood: tf.Tensor
+    norm_neg_log_likelihood: tf.Tensor
+    loss: tf.Tensor
+
+    jac_train: tf.Tensor
+    hessians_train: tf.Tensor
+    fim_train: tf.Tensor
 
     noise_model: str
 
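For orientation, a minimal sketch (hypothetical driver code, not part of this commit; TF1-style session API as used throughout this code base) of how the tensors collected by BatchedDataModelGraphGLM might be consumed:

import tensorflow as tf

def collect_batch_metrics(session: tf.Session, model):
    # Fetch all per-batch training metrics in one run() call so that
    # they are evaluated on the same minibatch draw.
    loss, jac, hess, fim = session.run(
        (model.loss, model.jac_train, model.hessians_train, model.fim_train)
    )
    return loss, jac, hess, fim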
@@ -124,7 +147,7 @@ def __init__(
         self.gradients_batch = gradients_batch
 
     def gradients_full_byfeature(self):
-        gradients_full_all = tf.transpose(self.full_data_model.neg_jac_train)
+        gradients_full_all = tf.transpose(self.full_data_model.jac_train.neg_jac)
         gradients_full = tf.concat([
             # tf.gradients(full_data_model.norm_neg_log_likelihood,
             #              model_vars.params_by_gene[i])[0]
@@ -137,7 +160,7 @@ def gradients_full_byfeature(self):
         self.gradients_full_raw = gradients_full
 
     def gradients_batched_byfeature(self):
-        gradients_batch_all = tf.transpose(self.batch_jac.neg_jac)
+        gradients_batch_all = tf.transpose(self.batched_data_model.jac_train.neg_jac)
         gradients_batch = tf.concat([
             # tf.gradients(batch_model.norm_neg_log_likelihood,
             #              model_vars.params_by_gene[i])[0]
@@ -150,11 +173,11 @@ def gradients_batched_byfeature(self):
         self.gradients_batch_raw = gradients_batch
 
     def gradients_full_global(self):
-        gradients_full = tf.transpose(self.full_data_model.neg_jac_train)
+        gradients_full = tf.transpose(self.full_data_model.jac_train.neg_jac)
         self.gradients_full_raw = gradients_full
 
     def gradients_batched_global(self):
-        gradients_batch = tf.transpose(self.batch_jac.neg_jac)
+        gradients_batch = tf.transpose(self.batched_data_model.jac_train.neg_jac)
         self.gradients_batch_raw = gradients_batch
 
 
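The four gradient helpers above differ only in their source tensor; the transpose flips the negative Jacobian from a features-by-parameters layout into parameters-by-features, so each column can be sliced out per feature. A small numpy sketch of that layout change (the shape convention is an assumption, not stated in this diff):

import numpy as np

neg_jac = np.arange(6.0).reshape(2, 3)  # assumed layout: 2 features x 3 parameters
gradients = np.transpose(neg_jac)       # 3 parameters x 2 features
# gradients[:, i] is now the gradient vector of feature i, the
# column layout that the per-feature tf.concat above expects.
assert gradients.shape == (3, 2)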
@@ -193,10 +216,10 @@ def __init__(
     ):
         if provide_optimizers["nr"]:
             nr_update_full_raw, nr_update_batched_raw = self.build_updates(
-                full_lhs=self.full_data_model.neg_hessian_train,
-                batched_lhs=self.batch_hessians.neg_hessian,
-                full_rhs=self.full_data_model.neg_jac_train,
-                batched_rhs=self.batch_jac.neg_jac,
+                full_lhs=self.full_data_model.hessians_train.neg_hessian,
+                batched_lhs=self.batched_data_model.hessians_train.neg_hessian,
+                full_rhs=self.full_data_model.jac_train.neg_jac,
+                batched_rhs=self.batched_data_model.jac_train.neg_jac,
                 termination_type=termination_type,
                 psd=False
             )
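The Newton-Raphson wiring above passes the negative Hessian as the left-hand side and the negative Jacobian as the right-hand side of a linear system. A minimal numpy sketch of the update this encodes (a single dense system is assumed here; build_updates itself batches this over features):

import numpy as np

def nr_update(theta, neg_hessian, neg_jac, lr=1.0):
    # The objective is the negative log-likelihood, so neg_hessian is
    # its Hessian and neg_jac its gradient; the Newton step is then
    # theta - H^{-1} g, obtained from one linear solve.
    delta = np.linalg.solve(neg_hessian, neg_jac)
    return theta - lr * delta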
@@ -219,9 +242,9 @@ def __init__(
             # passed here with psd=True.
             irls_update_a_full, irls_update_a_batched = self.build_updates(
                 full_lhs=self.full_data_model.fim_train.fim_a,
-                batched_lhs=self.batch_fim.fim_a,
-                full_rhs=self.full_data_model.neg_jac_train_a,
-                batched_rhs=self.batch_jac.neg_jac_a,
+                batched_lhs=self.batched_data_model.fim_train.fim_a,
+                full_rhs=self.full_data_model.jac_train.neg_jac_a,
+                batched_rhs=self.batched_data_model.jac_train.neg_jac_a,
                 termination_type=termination_type,
                 psd=True
             )
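The psd flag signals that the left-hand side is positive (semi-)definite, which holds for the Fisher information matrix by construction and permits a cheaper factorization. A sketch of the distinction, assuming psd=True enables a Cholesky-based solve (an assumption about build_updates, not confirmed by this diff):

import numpy as np
from scipy.linalg import cho_factor, cho_solve, solve

def solve_update(lhs, rhs, psd):
    if psd:
        # Cholesky factorization: valid only for symmetric positive
        # definite systems, roughly half the cost of a general LU solve.
        return cho_solve(cho_factor(lhs), rhs)
    return solve(lhs, rhs)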
@@ -232,11 +255,11 @@ def __init__(
         if train_r:
             irls_update_b_full, irls_update_b_batched = self.build_updates(
                 full_lhs=self.full_data_model.fim_train.fim_b,
-                batched_lhs=self.batch_fim.fim_b,
-                full_rhs=self.full_data_model.neg_jac_train_b,
-                batched_rhs=self.batch_jac.neg_jac_b,
+                batched_lhs=self.batched_data_model.fim_train.fim_b,
+                full_rhs=self.full_data_model.jac_train.neg_jac_b,
+                batched_rhs=self.batched_data_model.jac_train.neg_jac_b,
                 termination_type=termination_type,
-                psd=False
+                psd=True  # TODO prove
             )
         else:
             irls_update_b_full = None
@@ -608,8 +631,11 @@ def __init__(
             num_design_scale_params,
             num_loc_params,
             num_scale_params,
-            graph: tf.Graph = None,
-            batch_size: int = None,
+            graph: tf.Graph,
+            batch_size: int,
+            constraints_loc: xr.DataArray,
+            constraints_scale: xr.DataArray,
+            dtype
     ):
         """
 
@@ -622,6 +648,16 @@ def __init__(
         :param num_design_scale_params: int
             Number of parameters per feature in scale model.
         :param graph: tf.Graph
+        :param constraints_loc: tensor (all parameters x independent parameters)
+            Tensor that encodes how the complete parameter set, which includes
+            dependent parameters, arises from the independent parameters:
+            all = <constraints, indep>. This tensor describes this relation for the mean model.
+            This form of constraint is used in vector generalized linear models (VGLMs).
+        :param constraints_scale: tensor (all parameters x independent parameters)
+            Tensor that encodes how the complete parameter set, which includes
+            dependent parameters, arises from the independent parameters:
+            all = <constraints, indep>. This tensor describes this relation for the dispersion model.
+            This form of constraint is used in vector generalized linear models (VGLMs).
         """
         TFEstimatorGraph.__init__(
             self=self,
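A worked toy example of the constraint encoding described in the new docstring: three model parameters arise from two independent parameters, with the third constrained to be the negative sum of the first two (values hypothetical):

import numpy as np

constraints = np.array([
    [ 1.0,  0.0],
    [ 0.0,  1.0],
    [-1.0, -1.0],
])                                  # all parameters x independent parameters
indep = np.array([2.0, 3.0])
all_params = constraints @ indep    # all = <constraints, indep>
# all_params == [2., 3., -5.]: the dependent entry makes the group
# sum to zero, a typical VGLM-style identifiability constraint.
assert np.isclose(all_params.sum(), 0.0)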
@@ -636,19 +672,29 @@ def __init__(
         self.num_scale_params = num_scale_params
         self.batch_size = batch_size
 
+        self.constraints_loc = self._set_constraints(
+            constraints=constraints_loc,
+            dtype=dtype
+        )
+        self.constraints_scale = self._set_constraints(
+            constraints=constraints_scale,
+            dtype=dtype
+        )
+
+        self.learning_rate = tf.placeholder(dtype, shape=(), name="learning_rate")
+
     def _set_constraints(
             self,
             constraints,
-            design,
             dtype
     ):
         if constraints is None:
             return tf.eye(
-                num_rows=tf.constant(design.shape[1], shape=(), dtype="int32"),
+                num_rows=tf.constant(self.num_design_loc_params, shape=(), dtype="int32"),
                 dtype=dtype
             )
         else:
-            assert constraints.shape[0] == design.shape[1], "constraint dimension mismatch"
+            assert constraints.shape[0] == self.num_design_loc_params, "constraint dimension mismatch"
             return tf.cast(constraints, dtype=dtype)
 
     @abc.abstractmethod
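Behavioral sketch of the refactored _set_constraints fallback (numpy stand-in, not the TF code above): when no constraints are supplied, an identity matrix is returned, i.e. every parameter is treated as independent. Note that the refactored TF code validates against num_design_loc_params for both the loc and the scale call:

import numpy as np

def set_constraints(constraints, num_params):
    if constraints is None:
        return np.eye(num_params)  # every parameter independent
    assert constraints.shape[0] == num_params, "constraint dimension mismatch"
    return np.asarray(constraints)

assert np.array_equal(set_constraints(None, 3), np.eye(3))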