
Commit f9e5d80

improved default handling for Newton-Raphson in estimator.train
Set the default learning rate to 1 and batching to False. All settings are now also output to the debug logger.
1 parent 62cc15a commit f9e5d80
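
From the caller's side the change looks like this; a minimal usage sketch (assuming an already-constructed nb_glm Estimator named `estimator`, which is not part of this diff):

    # Newton-Raphson aliases now default to learning_rate=1:
    estimator.train(optim_algo="nr")
    # Any other rate in Newton-Raphson mode triggers the new warning:
    estimator.train(optim_algo="newton-raphson", learning_rate=0.5)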

1 file changed

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 36 additions & 8 deletions
@@ -964,7 +964,7 @@ def _scaffold(self):
         return scaffold
 
     def train(self, *args,
-              learning_rate=0.5,
+              learning_rate=None,
               convergence_criteria="t_test",
               loss_window_size=100,
               stopping_criteria=0.05,
@@ -1018,12 +1018,43 @@ def train(self, *args,
         # check if r was initialized with MLE
         train_r = self._train_r
 
-        if use_batching:
-            loss = self.model.loss
-            if optim_algo.lower() == "newton" or \
+        # Check whether Newton-Raphson is desired:
+        newton_raphson_mode = False
+        if optim_algo.lower() == "newton" or \
                 optim_algo.lower() == "newton-raphson" or \
                 optim_algo.lower() == "newton_raphson" or \
                 optim_algo.lower() == "nr":
+            newton_raphson_mode = True
+        # Set learning rate defaults if not set by user.
+        if learning_rate is None:
+            if newton_raphson_mode:
+                learning_rate = 1
+            else:
+                learning_rate = 0.5
+
+        # Check that Newton-Raphson is called properly:
+        if newton_raphson_mode:
+            if learning_rate != 1:
+                logger.warning("Newton-Raphson in nb_glm is used with learning rate " + str(learning_rate) +
+                               ". Newton-Raphson should only be used with learning rate = 1.")
+
+        # Report all parameters after all defaults were imputed in settings:
+        logger.debug("Optimizer settings in nb_glm Estimator.train():")
+        logger.debug("learning_rate " + str(learning_rate))
+        logger.debug("convergence_criteria " + str(convergence_criteria))
+        logger.debug("loss_window_size " + str(loss_window_size))
+        logger.debug("stopping_criteria " + str(stopping_criteria))
+        logger.debug("train_mu " + str(train_mu))
+        logger.debug("train_r " + str(train_r))
+        logger.debug("use_batching " + str(use_batching))
+        logger.debug("optim_algo " + str(optim_algo))
+        if len(kwargs) > 0:
+            logger.debug("**kwargs: ")
+            logger.debug(kwargs)
+
+        if use_batching:
+            loss = self.model.loss
+            if newton_raphson_mode:
                 train_op = self.model.newton_raphson_batched_op
         elif train_mu:
             if train_r:
@@ -1038,10 +1069,7 @@ def train(self, *args,
             return
         else:
             loss = self.model.full_loss
-            if optim_algo.lower() == "newton" or \
-                optim_algo.lower() == "newton-raphson" or \
-                optim_algo.lower() == "newton_raphson" or \
-                optim_algo.lower() == "nr":
+            if newton_raphson_mode:
                 train_op = self.model.newton_raphson_op
         elif train_mu:
             if train_r:
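
Standalone, the imputation logic this commit adds can be exercised outside the estimator. The helper below is illustrative only (the function name, logger setup, and asserts are not part of batchglm); it mirrors the alias check, the default imputation, and the warning from the diff:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("nb_glm")

    # Illustrative helper mirroring the diff's logic; not batchglm API.
    def resolve_learning_rate(optim_algo, learning_rate=None):
        # Any of the four aliases selects Newton-Raphson mode.
        newton_raphson_mode = optim_algo.lower() in (
            "newton", "newton-raphson", "newton_raphson", "nr")
        # Impute the default only when the user did not set a rate.
        if learning_rate is None:
            learning_rate = 1 if newton_raphson_mode else 0.5
        # Warn when Newton-Raphson is combined with a non-unit rate.
        if newton_raphson_mode and learning_rate != 1:
            logger.warning("Newton-Raphson should only be used with "
                           "learning rate = 1, got %s", learning_rate)
        logger.debug("learning_rate %s", learning_rate)
        return learning_rate, newton_raphson_mode

    assert resolve_learning_rate("nr") == (1, True)
    assert resolve_learning_rate("gradient_descent") == (0.5, False)
    assert resolve_learning_rate("NR", 0.5) == (0.5, True)  # emits the warning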
