Skip to content

Commit f64aad3

Browse files
changed to small batches in newton unit test which recapitulates training divergence for small batches
1 parent dc408d1 commit f64aad3

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

batchglm/unit_test/test_nb_glm_newton.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import batchglm.data as data_utils
1212
from batchglm.api.models.nb_glm import Simulator, Estimator, InputData
13+
import batchglm.pkg_constants as pkg_constants
1314

1415

1516
def estimate_adam_full(input_data: InputData, working_dir: str):
@@ -51,7 +52,7 @@ def estimate_nr_full(input_data: InputData, working_dir: str):
5152
return estimator
5253

5354
def estimate_nr_batched(input_data: InputData, working_dir: str):
54-
estimator = Estimator(input_data, batch_size=500)
55+
estimator = Estimator(input_data, batch_size=50)
5556
estimator.initialize(
5657
working_dir=working_dir,
5758
save_checkpoint_steps=20,
@@ -140,6 +141,7 @@ def test_newton_batched(self):
140141
os.makedirs(wd, exist_ok=True)
141142

142143
t0 = time.time()
144+
pkg_constants.JACOBIAN_MODE = "analytic"
143145
estimator = estimate_nr_batched(idata, wd)
144146
t1 = time.time()
145147
self._estims.append(estimator)
@@ -165,6 +167,7 @@ def test_newton_full(self):
165167
os.makedirs(wd, exist_ok=True)
166168

167169
t0 = time.time()
170+
pkg_constants.JACOBIAN_MODE = "analytic"
168171
estimator = estimate_nr_full(idata, wd)
169172
t1 = time.time()
170173
self._estims.append(estimator)
@@ -189,6 +192,7 @@ def test_newton_series(self):
189192
os.makedirs(wd, exist_ok=True)
190193

191194
t0 = time.time()
195+
pkg_constants.JACOBIAN_MODE = "analytic"
192196
estimator = estimate_nr_series(idata, wd)
193197
t1 = time.time()
194198
self._estims.append(estimator)

0 commit comments

Comments
 (0)