@@ -30,12 +30,14 @@ def estimate_adam_full(input_data: InputData, working_dir: str):
3030 input_data .save (os .path .join (working_dir , "input_data.h5" ))
3131
3232 estimator .train_sequence (training_strategy = [
33- {'convergence_criteria' : 't_test' ,
34- 'learning_rate' : 0.1 ,
35- 'loss_window_size' : 20 ,
36- 'optim_algo' : 'ADAM' ,
37- 'stop_at_loss_change' : 1e-8 ,
38- 'use_batching' : False }
33+ {
34+ "convergence_criteria" : "scaled_moving_average" ,
35+ "stopping_criteria" : 1e-6 ,
36+ "loss_window_size" : 5 ,
37+ "use_batching" : False ,
38+ "optim_algo" : "ADAM" ,
39+ "learning_rate" : 0.1
40+ }
3941 ])
4042
4143 return estimator
@@ -55,8 +57,8 @@ def estimate_nr_full(input_data: InputData, working_dir: str):
5557 estimator .train_sequence (training_strategy = [
5658 {
5759 "convergence_criteria" : "scaled_moving_average" ,
58- "stopping_criteria" : 1e-10 ,
59- "loss_window_size" : 10 ,
60+ "stopping_criteria" : 1e-6 ,
61+ "loss_window_size" : 5 ,
6062 "use_batching" : False ,
6163 "optim_algo" : "newton" ,
6264 },
@@ -79,8 +81,8 @@ def estimate_nr_batched(input_data: InputData, working_dir: str):
7981 estimator .train_sequence (training_strategy = [
8082 {
8183 "convergence_criteria" : "scaled_moving_average" ,
82- "stopping_criteria" : 1e-8 ,
83- "loss_window_size" : 10 ,
84+ "stopping_criteria" : 1e-6 ,
85+ "loss_window_size" : 5 ,
8486 "use_batching" : True ,
8587 "optim_algo" : "newton" ,
8688 },
0 commit comments