Skip to content

Commit 3432635

Browse files
updated newton unit test to run faster
1 parent 4d9226e commit 3432635

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

batchglm/unit_test/test_nb_glm_newton.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ def estimate_adam_full(input_data: InputData, working_dir: str):
3030
input_data.save(os.path.join(working_dir, "input_data.h5"))
3131

3232
estimator.train_sequence(training_strategy=[
33-
{'convergence_criteria': 't_test',
34-
'learning_rate': 0.1,
35-
'loss_window_size': 20,
36-
'optim_algo': 'ADAM',
37-
'stop_at_loss_change': 1e-8,
38-
'use_batching': False}
33+
{
34+
"convergence_criteria": "scaled_moving_average",
35+
"stopping_criteria": 1e-6,
36+
"loss_window_size": 5,
37+
"use_batching": False,
38+
"optim_algo": "ADAM",
39+
"learning_rate": 0.1
40+
}
3941
])
4042

4143
return estimator
@@ -55,8 +57,8 @@ def estimate_nr_full(input_data: InputData, working_dir: str):
5557
estimator.train_sequence(training_strategy=[
5658
{
5759
"convergence_criteria": "scaled_moving_average",
58-
"stopping_criteria": 1e-10,
59-
"loss_window_size": 10,
60+
"stopping_criteria": 1e-6,
61+
"loss_window_size": 5,
6062
"use_batching": False,
6163
"optim_algo": "newton",
6264
},
@@ -79,8 +81,8 @@ def estimate_nr_batched(input_data: InputData, working_dir: str):
7981
estimator.train_sequence(training_strategy=[
8082
{
8183
"convergence_criteria": "scaled_moving_average",
82-
"stopping_criteria": 1e-8,
83-
"loss_window_size": 10,
84+
"stopping_criteria": 1e-6,
85+
"loss_window_size": 5,
8486
"use_batching": True,
8587
"optim_algo": "newton",
8688
},

0 commit comments

Comments
 (0)