@@ -84,27 +84,26 @@ def _coef_invariant_b(
8484 Value of mean model by observation and feature.
8585 :param r: tf.tensor observations x features
8686 Value of dispersion model by observation and feature.
87- :param dtype: dtype
8887 :return const: tf.tensor observations x features
8988 Coefficient invariant terms of hessian of
9089 given observations and features.
9190 """
92- scalar_one = tf .constant (1 , shape = [ 1 , 1 ] , dtype = X .dtype )
91+ scalar_one = tf .constant (1 , shape = () , dtype = X .dtype )
9392 # Pre-define sub-graphs that are used multiple times:
94- r_plus_mu = tf . add ( r , mu )
95- r_plus_x = tf . add ( r , X )
93+ r_plus_mu = r + mu
94+ r_plus_x = r + X
9695 # Define graphs for individual terms of constant term of hessian:
9796 const1 = tf .subtract (
9897 tf .math .digamma (x = r_plus_x ),
9998 tf .math .digamma (x = r )
10099 )
101- const2 = tf .negative (tf . divide ( r_plus_x , r_plus_mu ) )
100+ const2 = tf .negative (r_plus_x / r_plus_mu )
102101 const3 = tf .add (
103102 tf .log (r ),
104- tf . subtract ( scalar_one , tf .log (r_plus_mu ) )
103+ scalar_one - tf .log (r_plus_mu )
105104 )
106105 const = tf .add_n ([const1 , const2 , const3 ]) # [observations, features]
107- const = tf . multiply ( r , const )
106+ const = r * const
108107 return const
109108
110109
@@ -159,9 +158,9 @@ def __init__(
159158 Whether an iterator or a tensor (single yield of an iterator) is given
160159 in
161160 """
162- if constraints_loc != None and mode != "tf" :
161+ if constraints_loc is not None and mode != "tf" :
163162 raise ValueError ("closed form hessian does not work if constraints_loc is not None" )
164- if constraints_scale != None and mode != "tf" :
163+ if constraints_scale is not None and mode != "tf" :
165164 raise ValueError ("closed form hessian does not work if constraints_scale is not None" )
166165
167166 if mode == "analytic" :
@@ -378,15 +377,15 @@ def _red(prev, cur):
378377 p_shape_a = model_vars .a .shape [0 ]
379378 p_shape_b = model_vars .b .shape [0 ]
380379
381- if iterator == True and batch_model is None :
380+ if iterator == True and batch_model is None :
382381 J = op_utils .map_reduce (
383382 last_elem = tf .gather (sample_indices , tf .size (sample_indices ) - 1 ),
384383 data = batched_data ,
385384 map_fn = _assemble_bybatch ,
386385 reduce_fn = _red ,
387386 parallel_iterations = pkg_constants .TF_LOOP_PARALLEL_ITERATIONS
388387 )
389- elif iterator == False and batch_model is None :
388+ elif iterator == False and batch_model is None :
390389 J = _assemble_bybatch (
391390 idx = sample_indices ,
392391 data = batched_data
0 commit comments