 from .base import param_bounds, tf_clip_param, np_clip_param, apply_constraints

 from .external import AbstractEstimator, XArrayEstimatorStore, InputData, Model, MonitoredTFEstimator, TFEstimatorGraph
-from .external import nb_utils, train_utils, op_utils, rand_utils
+from .external import nb_utils, train_utils, op_utils, rand_utils, data_utils
 from .external import pkg_constants
 from .hessians import Hessians
+from .jacobians import Jacobians

 logger = logging.getLogger(__name__)

@@ -91,6 +92,20 @@ def map_model(idx, data) -> BasicModelGraph:
                 constraints_scale=constraints_scale,
                 model_vars=model_vars,
                 mode=pkg_constants.HESSIAN_MODE,
+                iterator=True,
+                dtype=dtype
+            )
+
+        with tf.name_scope("jacobians"):
+            jacobians = Jacobians(
+                batched_data=batched_data,
+                sample_indices=sample_indices,
+                batch_model=None,
+                constraints_loc=constraints_loc,
+                constraints_scale=constraints_scale,
+                model_vars=model_vars,
+                mode=pkg_constants.JACOBIAN_MODE,
+                iterator=True,
                 dtype=dtype
             )

@@ -121,6 +136,8 @@ def map_model(idx, data) -> BasicModelGraph:
         self.norm_neg_log_likelihood = norm_neg_log_likelihood
         self.loss = loss

+        self.jac = jacobians.jac
+        self.neg_jac = jacobians.neg_jac
         self.hessian = hessians.hessian
         self.neg_hessian = hessians.neg_hessian

@@ -235,7 +252,20 @@ def __init__(
             # use the mean loss to keep a constant learning rate independently of the batch size
             batch_loss = batch_model.loss

-            # Define the hessian on the batched model:
+            # Define the jacobian on the batched model for newton-raphson:
+            batch_jac = Jacobians(
+                batched_data=batch_data,
+                sample_indices=batch_sample_index,
+                batch_model=batch_model,
+                constraints_loc=constraints_loc,
+                constraints_scale=constraints_scale,
+                model_vars=model_vars,
+                mode="analytic",
+                iterator=False,
+                dtype=dtype
+            )
+
+            # Define the hessian on the batched model for newton-raphson:
             batch_hessians = Hessians(
                 batched_data=batch_data,
                 singleobs_data=None,
@@ -366,21 +396,23 @@ def __init__(
                 name="full_data_trainers_b_only"
             )
             with tf.name_scope("full_gradient"):
-                full_gradient = full_data_trainers.gradient[0][0]
-                full_gradient = tf.reduce_sum(tf.abs(full_gradient), axis=0)
+                # full_gradient = full_data_trainers.gradient[0][0]
+                # full_gradient = tf.reduce_sum(tf.abs(full_gradient), axis=0)
+                full_gradient = full_data_model.neg_jac
                 # full_gradient = tf.add_n(
                 #     [tf.reduce_sum(tf.abs(grad), axis=0) for (grad, var) in full_data_trainers.gradient])

             with tf.name_scope("newton-raphson"):
                 # tf.gradients(- full_data_model.log_likelihood, [model_vars.a, model_vars.b])
                 # Full data model:
-                param_grad_vec = tf.gradients(- full_data_model.log_likelihood, model_vars.params)[0]
-                param_grad_vec_t = tf.transpose(param_grad_vec)
+                param_grad_vec = full_data_model.neg_jac
+                # param_grad_vec = tf.gradients(- full_data_model.log_likelihood, model_vars.params)[0]
+                # param_grad_vec_t = tf.transpose(param_grad_vec)

                 delta_t = tf.squeeze(tf.matrix_solve_ls(
                     full_data_model.neg_hessian,
                     # (full_data_model.hessians + tf.transpose(full_data_model.hessians, perm=[0, 2, 1])) / 2,  # don't need this with closed forms
-                    tf.expand_dims(param_grad_vec_t, axis=-1),
+                    tf.expand_dims(param_grad_vec, axis=-1),
                     fast=False
                 ), axis=-1)
                 delta = tf.transpose(delta_t)
@@ -392,13 +424,14 @@ def __init__(
                 )

                 # Batched data model:
-                param_grad_vec_batched = tf.gradients(- batch_model.log_likelihood,
-                                                      model_vars.params)[0]
-                param_grad_vec_batched_t = tf.transpose(param_grad_vec_batched)
+                param_grad_vec_batched = batch_jac.neg_jac
+                # param_grad_vec_batched = tf.gradients(- batch_model.log_likelihood,
+                #                                       model_vars.params)[0]
+                # param_grad_vec_batched_t = tf.transpose(param_grad_vec_batched)

                 delta_batched_t = tf.squeeze(tf.matrix_solve_ls(
                     batch_hessians.neg_hessian,
-                    tf.expand_dims(param_grad_vec_batched_t, axis=-1),
+                    tf.expand_dims(param_grad_vec_batched, axis=-1),
                     fast=False
                 ), axis=-1)
                 delta_batched = tf.transpose(delta_batched_t)
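The two newton-raphson hunks above swap the autodiff gradient (tf.gradients) for the analytic negative Jacobian and feed it into a per-feature linear solve against the negative Hessian. As a reference point only (not part of this commit), a minimal NumPy sketch of that update; the shapes are assumptions inferred from how neg_jac, neg_hessian and model_vars.params are used here:

    import numpy as np

    def newton_step(params, neg_jac, neg_hessian):
        # params:      [params, features]   (orientation of model_vars.params)
        # neg_jac:     [features, params]   gradient of the negative log-likelihood
        # neg_hessian: [features, params, params]
        # Solve neg_hessian[f] @ delta[f] = neg_jac[f] for every feature f,
        # mirroring tf.matrix_solve_ls(..., fast=False) above.
        delta_t = np.linalg.solve(neg_hessian, neg_jac[..., None])[..., 0]  # [features, params]
        return params - delta_t.T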
@@ -741,6 +774,9 @@ def __init__(
             shape=[input_data.num_observations, input_data.num_features]
         )

+        groupwise_means = None  # [groups, features]
+        overall_means = None  # [1, features]
+        logger.debug(" * Initialize mean model")
         if isinstance(init_a, str):
             # Chose option if auto was chosen
             if init_a.lower() == "auto":
@@ -766,12 +802,12 @@ def __init__(
                    if size_factors_init is not None:
                        X = np.divide(X, size_factors_init)

-                   mean = X.groupby("group").mean(dim="observations").values
+                   groupwise_means = X.groupby("group").mean(dim="observations").values
                    # clipping
-                   mean = np_clip_param(mean, "mu")
+                   groupwise_means = np_clip_param(groupwise_means, "mu")
                    # mean = np.nextafter(0, 1, out=mean.values, where=mean == 0, dtype=mean.dtype)

-                   a = np.log(mean)
+                   a = np.log(groupwise_means)
                    if input_data.constraints_loc is not None:
                        a_constraints = np.zeros([input_data.constraints_loc.shape[0], a.shape[1]])
                        # Add constraints (sum to zero) to value vector to remove structural unidentifiability.
@@ -781,7 +817,9 @@ def __init__(
                        # inv_design = np.linalg.inv(unique_design_loc)  # NOTE: this is exact if full rank!
                        # init_a = np.matmul(inv_design, a)
                        #
-                       # Better option: use least-squares solver to calculate a'
+                       # Use least-squares solver to calculate a':
+                       # This is faster and more accurate than using matrix inversion.
+                       logger.debug(" ** Solve lstsq problem")
                        a_prime = np.linalg.lstsq(unique_design_loc, a, rcond=None)
                        init_a = a_prime[0]
                        # stat_utils.rmsd(np.exp(unique_design_loc @ init_a), mean)
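The closed-form location initialization above works on the log scale: group-wise means are clipped, log-transformed, and then mapped onto the design-matrix parameterization with a least-squares solve rather than an explicit matrix inverse. A standalone sketch with hypothetical shapes (not part of this commit):

    import numpy as np

    # Hypothetical inputs, one row per distinct design group:
    # unique_design_loc: [groups, design_loc_params], groupwise_means: [groups, features]
    log_means = np.log(groupwise_means)  # log link; np_clip_param would bound the means first
    init_a, *_ = np.linalg.lstsq(unique_design_loc, log_means, rcond=None)
    # init_a: [design_loc_params, features]; coincides with the exact solution
    # when unique_design_loc is square and full rank.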
@@ -799,18 +837,19 @@ def __init__(
                except np.linalg.LinAlgError:
                    logger.warning("Closed form initialization failed!")
            elif init_a.lower() == "standard":
-               mean = input_data.X.mean(dim="observations").values  # directly calculate the mean
+               overall_means = input_data.X.mean(dim="observations").values  # directly calculate the mean
                # clipping
-               mean = np_clip_param(mean, "mu")
+               overall_means = np_clip_param(overall_means, "mu")
                # mean = np.nextafter(0, 1, out=mean, where=mean == 0, dtype=mean.dtype)

                init_a = np.zeros([input_data.num_design_loc_params, input_data.num_features])
-               init_a[0, :] = np.log(mean)
+               init_a[0, :] = np.log(overall_means)
                self._train_mu = True

                logger.info("Using standard initialization for mean")
                logger.info("Should train mu: %s", self._train_mu)

+       logger.debug(" * Initialize dispersion model")
        if isinstance(init_b, str):
            if init_b.lower() == "auto":
                init_b = "closed_form"
@@ -837,14 +876,20 @@ def __init__(
                    if input_data.size_factors is not None:
                        X = np.divide(X, size_factors_init)

-                   Xdiff = X - np.exp(input_data.design_loc @ init_a)
+                   # Xdiff = X - np.exp(input_data.design_loc @ init_a)
+                   # Define an xarray version of init_a so that Xdiff can be evaluated lazily by dask.
+                   init_a_xr = data_utils.xarray_from_data(init_a, dims=("design_loc_params", "features"))
+                   init_a_xr.coords["design_loc_params"] = input_data.design_loc.coords["design_loc_params"]
+                   logger.debug(" ** Define Xdiff")
+                   Xdiff = X - np.exp(input_data.design_loc.dot(init_a_xr))
                    variance = np.square(Xdiff).groupby("group").mean(dim="observations")

-                   group_mean = X.groupby("group").mean(dim="observations")
-                   denominator = np.fmax(variance - group_mean, 0)
+                   if groupwise_means is None:
+                       groupwise_means = X.groupby("group").mean(dim="observations")
+                   denominator = np.fmax(variance - groupwise_means, 0)
                    denominator = np.nextafter(0, 1, out=denominator.values, where=denominator == 0,
                                               dtype=denominator.dtype)
-                   r = np.asarray(np.square(group_mean) / denominator)
+                   r = np.asarray(np.square(groupwise_means) / denominator)
                    # clipping
                    r = np_clip_param(r, "r")
                    # r = np.nextafter(0, 1, out=r.values, where=r == 0, dtype=r.dtype)
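The dispersion initialization above is a method-of-moments estimate for the negative binomial model: with Var(x) = mu + mu^2 / r, solving for r gives r = mu^2 / (Var(x) - mu), which is what the group-wise computation in this hunk evaluates. A small NumPy sketch of the same estimate (not part of this commit):

    import numpy as np

    def nb_mom_dispersion(group_means, group_variances):
        # r = mu^2 / (Var - mu); guard the denominator so that Var <= mu does not
        # yield a non-positive dispersion (mirrors the np.fmax / np.nextafter lines above).
        denom = np.fmax(group_variances - group_means, 0.0)
        denom = np.where(denom == 0.0, np.nextafter(0, 1), denom)
        return np.square(group_means) / denom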
@@ -860,7 +905,9 @@ def __init__(
                        # inv_design = np.linalg.inv(unique_design_scale)  # NOTE: this is exact if full rank!
                        # init_b = np.matmul(inv_design, b)
                        #
-                       # Better option: use least-squares solver to calculate b''
+                       # Use least-squares solver to calculate b':
+                       # This is faster and more accurate than using matrix inversion.
+                       logger.debug(" ** Solve lstsq problem")
                        b_prime = np.linalg.lstsq(unique_design_scale, b, rcond=None)
                        init_b = b_prime[0]

@@ -981,6 +1028,7 @@ def fetch_fn(idx):
        else:
            init_b = init_b.astype(dtype)

+       logger.debug(" * Start creating model")
        with graph.as_default():
            # create model
            model = EstimatorGraph(
@@ -999,6 +1047,7 @@ def fetch_fn(idx):
                extended_summary=extended_summary,
                dtype=dtype
            )
+       logger.debug(" * Finished creating model")

        MonitoredTFEstimator.__init__(self, model)
