@@ -130,7 +130,7 @@ def map_model(idx, data) -> BasicModelGraph:
130130 else :
131131 hessians_train = hessians_full
132132
133- fim_train = FIM (
133+ fim_full = FIM (
134134 batched_data = batched_data ,
135135 sample_indices = sample_indices ,
136136 constraints_loc = constraints_loc ,
@@ -139,10 +139,27 @@ def map_model(idx, data) -> BasicModelGraph:
139139 mode = pkg_constants .HESSIAN_MODE ,
140140 noise_model = noise_model ,
141141 iterator = True ,
142- update_a = train_a ,
143- update_b = train_b ,
142+ update_a = True ,
143+ update_b = True ,
144144 dtype = dtype
145145 )
146+ # Fisher information matrix of submodel which is to be trained.
147+ if not train_a or not train_b :
148+ fim_train = FIM (
149+ batched_data = batched_data ,
150+ sample_indices = sample_indices ,
151+ constraints_loc = constraints_loc ,
152+ constraints_scale = constraints_scale ,
153+ model_vars = model_vars ,
154+ mode = pkg_constants .HESSIAN_MODE ,
155+ noise_model = noise_model ,
156+ iterator = True ,
157+ update_a = train_a ,
158+ update_b = train_b ,
159+ dtype = dtype
160+ )
161+ else :
162+ fim_train = fim_full
146163
147164 with tf .name_scope ("jacobians" ):
148165 # Jacobian of full model for reporting.
@@ -210,7 +227,8 @@ def map_model(idx, data) -> BasicModelGraph:
210227 self .neg_hessian = hessians_full .neg_hessian
211228 self .neg_hessian_train = hessians_train .neg_hessian
212229
213- self .fim = fim_train
230+ self .fim_full = fim_full
231+ self .fim_train = fim_train
214232
215233
216234class EstimatorGraphAll (EstimatorGraphGLM ):
@@ -439,20 +457,7 @@ def __init__(
439457 noise_model = noise_model ,
440458 dtype = dtype
441459 )
442- full_data_loss = full_data_model .loss
443- fisher_inv = op_utils .pinv (full_data_model .neg_hessian )
444-
445- # with tf.name_scope("hessian_diagonal"):
446- # hessian_diagonal = [
447- # tf.map_fn(
448- # # elems=tf.transpose(hess, perm=[2, 0, 1]),
449- # elems=hess,
450- # fn=tf.diag_part,
451- # parallel_iterations=pkg_constants.TF_LOOP_PARALLEL_ITERATIONS
452- # )
453- # for hess in full_data_model.hessians
454- # ]
455- # fisher_a, fisher_b = hessian_diagonal
460+ full_data_fisher_inv = op_utils .pinv (full_data_model .neg_hessian ) # TODO switch for fim
456461
457462 mu = full_data_model .mu
458463 r = full_data_model .r
@@ -473,16 +478,17 @@ def __init__(
473478 self .r = r
474479 self .sigma2 = sigma2
475480
481+ self .full_loss = full_data_model .loss
482+ self .full_log_likelihood = full_data_model .log_likelihood
476483 self .batch_probs = batch_model .probs
477484 self .batch_log_probs = batch_model .log_probs
478485 self .batch_log_likelihood = batch_model .norm_log_likelihood
479486
480487 self .sample_selection = sample_selection
481488 self .full_data_model = full_data_model
482489
483- self .full_loss = full_data_loss
484490 self .hessians = full_data_model .hessian
485- self .fisher_inv = fisher_inv
491+ self .fisher_inv = full_data_fisher_inv
486492
487493 self .idx_nonconverged = idx_nonconverged
488494
0 commit comments