Skip to content

Commit 64c91d8

Browse files
Fixed a bug where convergence wasn't reset to False after each training-strategy step completed
1 parent c05d2c0 commit 64c91d8

File tree

4 files changed

+49
-31
lines changed

4 files changed

+49
-31
lines changed

batchglm/train/tf/base_glm_all/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,15 @@ def train_sequence(self, training_strategy):
333333
if isinstance(training_strategy, Enum):
334334
training_strategy = training_strategy.value
335335
elif isinstance(training_strategy, str):
336-
training_strategy = self.TrainingStrategy[training_strategy].value
336+
training_strategy = self.TrainingStrategies[training_strategy].value
337337

338338
if training_strategy is None:
339-
training_strategy = self.TrainingStrategy.DEFAULT.value
339+
training_strategy = self.TrainingStrategies.DEFAULT.value
340340

341341
logger.info("training strategy:\n%s", pprint.pformat(training_strategy))
342342

343343
for idx, d in enumerate(training_strategy):
344+
self.model.model_vars.converged = False
344345
logger.info("Beginning with training sequence #%d", idx + 1)
345346
self.train(**d)
346347
logger.info("Training sequence #%d complete", idx + 1)

batchglm/train/tf/glm_nb/estimator.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from enum import Enum
21
import logging
32
from typing import Union
43

@@ -10,6 +9,7 @@
109
from .external import closedform_nb_glm_logmu, closedform_nb_glm_logphi
1110
from .estimator_graph import EstimatorGraph
1211
from .model import ProcessModel
12+
from .training_strategies import TrainingStrategies
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -20,33 +20,6 @@ class Estimator(EstimatorAll, AbstractEstimator, ProcessModel):
2020
Uses the natural logarithm as linker function.
2121
"""
2222

23-
class TrainingStrategy(Enum):
24-
AUTO = None
25-
DEFAULT = [
26-
{
27-
"convergence_criteria": "all_converged_ll",
28-
"stopping_criteria": 1e-6,
29-
"use_batching": False,
30-
"optim_algo": "irls",
31-
},
32-
]
33-
QUICK = [
34-
{
35-
"convergence_criteria": "all_converged_ll",
36-
"stopping_criteria": 1e-4,
37-
"use_batching": False,
38-
"optim_algo": "irls",
39-
},
40-
]
41-
EXACT = [
42-
{
43-
"convergence_criteria": "all_converged_ll",
44-
"stopping_criteria": 1e-8,
45-
"use_batching": False,
46-
"optim_algo": "irls",
47-
},
48-
]
49-
5023
def __init__(
5124
self,
5225
input_data: InputData,
@@ -62,6 +35,7 @@ def __init__(
6235
extended_summary=False,
6336
dtype="float64",
6437
):
38+
self.TrainingStrategies = TrainingStrategies
6539
EstimatorAll.__init__(
6640
self=self,
6741
input_data=input_data,
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from enum import Enum
2+
3+
class TrainingStrategies(Enum):
4+
5+
AUTO = None
6+
DEFAULT = [
7+
{
8+
"convergence_criteria": "all_converged_ll",
9+
"stopping_criteria": 1e-6,
10+
"use_batching": False,
11+
"optim_algo": "irls",
12+
},
13+
]
14+
QUICK = [
15+
{
16+
"convergence_criteria": "all_converged_ll",
17+
"stopping_criteria": 1e-3,
18+
"use_batching": True,
19+
"optim_algo": "irls",
20+
},
21+
{
22+
"convergence_criteria": "all_converged_ll",
23+
"stopping_criteria": 1e-6,
24+
"use_batching": False,
25+
"optim_algo": "irls",
26+
},
27+
]
28+
INEXACT = [
29+
{
30+
"convergence_criteria": "all_converged_ll",
31+
"stopping_criteria": 1e-4,
32+
"use_batching": False,
33+
"optim_algo": "irls",
34+
},
35+
]
36+
EXACT = [
37+
{
38+
"convergence_criteria": "all_converged_ll",
39+
"stopping_criteria": 1e-8,
40+
"use_batching": False,
41+
"optim_algo": "irls",
42+
},
43+
]

batchglm/unit_test/glm_all/test_acc_glm_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
else:
2929
raise ValueError("noise_model not recognized")
3030

31-
batch_size = 900
31+
batch_size = 200
3232
provide_optimizers = {"gd": True, "adam": True, "adagrad": True, "rmsprop": True, "nr": True, "irls": True}
3333
estimator = Estimator(
3434
input_data=simulator.input_data,

0 commit comments

Comments
 (0)