@@ -175,7 +175,6 @@ class Hessians:
175175 def __init__ (
176176 self ,
177177 batched_data : tf .data .Dataset ,
178- singleobs_data : tf .data .Dataset ,
179178 sample_indices : tf .Tensor ,
180179 constraints_loc ,
181180 constraints_scale ,
@@ -188,7 +187,6 @@ def __init__(
188187
189188 :param batched_data:
190189 Dataset iterator over mini-batches of data (used for training) or tf.Tensors of mini-batch.
191- :param singleobs_data: Dataset iterator over single observation batches of data.
192190 :param sample_indices: Indices of samples to be used.
193191 :param constraints_loc: Constraints for location model.
194192 Array with constraints in rows and model parameters in columns.
@@ -218,29 +216,14 @@ def __init__(
218216 evaluation of the hessian via the tf.hessian function,
219217 which is done by feature for implementation reasons.
220218 :param iterator: bool
221- Whether an iterator or a tensor (single yield of an iterator) is given
222- in
219+ Whether batched_data is an iterator or a tensor (such as single yield of an iterator).
223220 """
224- if constraints_loc != None and mode != "tf" :
221+ if constraints_loc is not None and mode != "tf" :
225222 raise ValueError ("closed form hessian does not work if constraints_loc is not None" )
226- if constraints_scale != None and mode != "tf" :
223+ if constraints_scale is not None and mode != "tf" :
227224 raise ValueError ("closed form hessian does not work if constraints_scale is not None" )
228225
229- if mode == "obs" :
230- logger .info ("Performance warning for hessian mode: " +
231- "obs_batched is strongly recommended as an alternative to obs." )
232- self .hessian = self .byobs (
233- batched_data = singleobs_data ,
234- sample_indices = sample_indices ,
235- constraints_loc = constraints_loc ,
236- constraints_scale = constraints_scale ,
237- model_vars = model_vars ,
238- batched = False ,
239- iterator = iterator ,
240- dtype = dtype
241- )
242- self .neg_hessian = tf .negative (self .hessian )
243- elif mode == "obs_batched" :
226+ if mode == "obs_batched" :
244227 self .hessian = self .byobs (
245228 batched_data = batched_data ,
246229 sample_indices = sample_indices ,
@@ -259,6 +242,7 @@ def __init__(
259242 constraints_loc = constraints_loc ,
260243 constraints_scale = constraints_scale ,
261244 model_vars = model_vars ,
245+ iterator = iterator ,
262246 dtype = dtype
263247 )
264248 self .neg_hessian = tf .negative (self .hessian )
@@ -272,6 +256,7 @@ def __init__(
272256 constraints_loc = constraints_loc ,
273257 constraints_scale = constraints_scale ,
274258 model_vars = model_vars ,
259+ iterator = iterator ,
275260 dtype = dtype
276261 )
277262 self .hessian = tf .negative (self .neg_hessian )
@@ -517,8 +502,8 @@ def _red(prev, cur):
517502 return tf .add (prev , cur )
518503
519504 params = model_vars .params
520- p_shape_a = model_vars .a .shape [0 ]
521- p_shape_b = model_vars .b .shape [0 ]
505+ p_shape_a = model_vars .a_var .shape [0 ] # This has to be _var to work with constraints.
506+ p_shape_b = model_vars .b_var .shape [0 ] # This has to be _var to work with constraints.
522507
523508 if iterator :
524509 H = op_utils .map_reduce (
@@ -542,6 +527,7 @@ def byfeature(
542527 constraints_loc ,
543528 constraints_scale ,
544529 model_vars : ModelVars ,
530+ iterator ,
545531 dtype
546532 ):
547533 """
@@ -685,18 +671,24 @@ def _red(prev, cur):
685671 return [tf .add (p , c ) for p , c in zip (prev , cur )]
686672
687673 params = model_vars .params
688- p_shape_a = model_vars .a .shape [0 ]
689- p_shape_b = model_vars .b .shape [0 ]
690-
691- H = op_utils .map_reduce (
692- last_elem = tf .gather (sample_indices , tf .size (sample_indices ) - 1 ),
693- data = batched_data ,
694- map_fn = _map ,
695- reduce_fn = _red ,
696- parallel_iterations = 1 ,
697- )
698- H = H [0 ]
699- return H
674+ p_shape_a = model_vars .a_var .shape [0 ] # This has to be _var to work with constraints.
675+ p_shape_b = model_vars .b_var .shape [0 ] # This has to be _var to work with constraints.
676+
677+ if iterator :
678+ H = op_utils .map_reduce (
679+ last_elem = tf .gather (sample_indices , tf .size (sample_indices ) - 1 ),
680+ data = batched_data ,
681+ map_fn = _map ,
682+ reduce_fn = _red ,
683+ parallel_iterations = 1
684+ )
685+ else :
686+ H = _map (
687+ idx = sample_indices ,
688+ data = batched_data
689+ )
690+
691+ return H [0 ]
700692
701693 def tf_byfeature (
702694 self ,
@@ -705,6 +697,7 @@ def tf_byfeature(
705697 constraints_loc ,
706698 constraints_scale ,
707699 model_vars : ModelVars ,
700+ iterator ,
708701 dtype
709702 ) -> List [tf .Tensor ]:
710703 """
@@ -804,20 +797,27 @@ def _map(idx, data):
804797 constraints_loc = constraints_loc ,
805798 constraints_scale = constraints_scale ,
806799 params = model_vars .params ,
807- p_shape_a = model_vars .a .shape [0 ],
808- p_shape_b = model_vars .b .shape [0 ],
800+ p_shape_a = model_vars .a_var .shape [0 ], # This has to be _var to work with constraints.
801+ p_shape_b = model_vars .b_var .shape [0 ], # This has to be _var to work with constraints.
809802 dtype = dtype ,
810803 size_factors = size_factors
811804 )
812805
813806 def _red (prev , cur ):
814807 return [tf .add (p , c ) for p , c in zip (prev , cur )]
815808
816- H = op_utils .map_reduce (
817- last_elem = tf .gather (sample_indices , tf .size (sample_indices ) - 1 ),
818- data = batched_data ,
819- map_fn = _map ,
820- reduce_fn = _red ,
821- parallel_iterations = 1 ,
822- )
809+ if iterator :
810+ H = op_utils .map_reduce (
811+ last_elem = tf .gather (sample_indices , tf .size (sample_indices ) - 1 ),
812+ data = batched_data ,
813+ map_fn = _map ,
814+ reduce_fn = _red ,
815+ parallel_iterations = 1
816+ )
817+ else :
818+ H = _map (
819+ idx = sample_indices ,
820+ data = batched_data
821+ )
822+
823823 return H [0 ]
0 commit comments