1010
1111import batchglm .data as data_utils
1212from batchglm .api .models .nb_glm import Simulator , Estimator , InputData
13+ import batchglm .pkg_constants as pkg_constants
1314
1415
1516def estimate_adam_full (input_data : InputData , working_dir : str ):
@@ -51,7 +52,7 @@ def estimate_nr_full(input_data: InputData, working_dir: str):
5152 return estimator
5253
5354def estimate_nr_batched (input_data : InputData , working_dir : str ):
54- estimator = Estimator (input_data , batch_size = 500 )
55+ estimator = Estimator (input_data , batch_size = 50 )
5556 estimator .initialize (
5657 working_dir = working_dir ,
5758 save_checkpoint_steps = 20 ,
@@ -140,6 +141,7 @@ def test_newton_batched(self):
140141 os .makedirs (wd , exist_ok = True )
141142
142143 t0 = time .time ()
144+ pkg_constants .JACOBIAN_MODE = "analytic"
143145 estimator = estimate_nr_batched (idata , wd )
144146 t1 = time .time ()
145147 self ._estims .append (estimator )
@@ -165,6 +167,7 @@ def test_newton_full(self):
165167 os .makedirs (wd , exist_ok = True )
166168
167169 t0 = time .time ()
170+ pkg_constants .JACOBIAN_MODE = "analytic"
168171 estimator = estimate_nr_full (idata , wd )
169172 t1 = time .time ()
170173 self ._estims .append (estimator )
@@ -189,6 +192,7 @@ def test_newton_series(self):
189192 os .makedirs (wd , exist_ok = True )
190193
191194 t0 = time .time ()
195+ pkg_constants .JACOBIAN_MODE = "analytic"
192196 estimator = estimate_nr_series (idata , wd )
193197 t1 = time .time ()
194198 self ._estims .append (estimator )
0 commit comments