Skip to content

Commit 2340bd4

Browse files
deprecated maths pdf, will add in once new version is finished
1 parent 64c91d8 commit 2340bd4

File tree

4 files changed

+32
-268
lines changed

4 files changed

+32
-268
lines changed

batchglm/train/tf/base/estimator.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import abc
22
from typing import Dict, Any, Union, List, Iterable
3-
43
import os
4+
import time
55
import datetime
66

77
import numpy as np
@@ -127,12 +127,19 @@ def should_stop(step):
127127
return False
128128

129129
while True:
130+
t0 = time.time()
130131
train_step, global_loss, _ = self.session.run(
131132
(self.model.global_step, loss, train_op),
132133
feed_dict=feed_dict
133134
)
135+
t1 = time.time()
134136

135-
tf.logging.info("Step: %d\tloss: %f", train_step, global_loss)
137+
tf.logging.info(
138+
"Step: \t%d\tloss: %f\t in %s sec",
139+
train_step,
140+
global_loss,
141+
str(np.round(t1 - t0, 3))
142+
)
136143

137144
# update last_loss every N+1st step:
138145
if train_step % len(loss_hist) == 1:
@@ -205,12 +212,19 @@ def train(self, *args,
205212
if convergence_criteria == "step":
206213
train_step = self.session.run(self.model.global_step, feed_dict=feed_dict)
207214
while train_step < stopping_criteria:
215+
t0 = time.time()
208216
train_step, global_loss, _ = self.session.run(
209217
(self.model.global_step, loss, train_op),
210218
feed_dict=feed_dict
211219
)
220+
t1 = time.time()
212221

213-
tf.logging.info("Step: %d\tloss: %f", train_step, global_loss)
222+
tf.logging.info(
223+
"Step: %d\tloss: %s",
224+
train_step,
225+
global_loss,
226+
str(np.round(t1 - t0, 3))
227+
)
214228
elif convergence_criteria in ["all_converged_ll", "all_converged_theta"]:
215229
# Evaluate initial value of convergence metric:
216230
if convergence_criteria == "all_converged_theta":
@@ -222,6 +236,7 @@ def train(self, *args,
222236

223237
while np.any(self.model.model_vars.converged == False):
224238
# Update convergence metric reference:
239+
t0 = time.time()
225240
metric_prev = metric_current
226241
train_step, global_loss, _ = self.session.run(
227242
(self.model.global_step, loss, train_op),
@@ -244,12 +259,14 @@ def train(self, *args,
244259
self.model.model_vars.converged,
245260
metric_delta < stopping_criteria
246261
)
262+
t1 = time.time()
247263

248264
tf.logging.info(
249-
"Step: \t%d\t loss: %f\t models converged %i",
265+
"Step: \t%d\t loss: \t%f\t models converged \t%i\t in %s sec",
250266
train_step,
251267
global_loss,
252-
np.sum(self.model.model_vars.converged).astype("int32")
268+
np.sum(self.model.model_vars.converged).astype("int32"),
269+
str(np.round(t1 - t0, 3))
253270
)
254271
else:
255272
self._train_to_convergence(

batchglm/train/tf/glm_nb/training_strategies.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,32 @@ class TrainingStrategies(Enum):
1111
"optim_algo": "irls",
1212
},
1313
]
14-
QUICK = [
14+
INEXACT = [
1515
{
1616
"convergence_criteria": "all_converged_ll",
17-
"stopping_criteria": 1e-3,
18-
"use_batching": True,
17+
"stopping_criteria": 1e-4,
18+
"use_batching": False,
1919
"optim_algo": "irls",
2020
},
21+
]
22+
EXACT = [
2123
{
2224
"convergence_criteria": "all_converged_ll",
23-
"stopping_criteria": 1e-6,
25+
"stopping_criteria": 1e-8,
2426
"use_batching": False,
2527
"optim_algo": "irls",
2628
},
2729
]
28-
INEXACT = [
30+
QUICK = [
2931
{
3032
"convergence_criteria": "all_converged_ll",
31-
"stopping_criteria": 1e-4,
32-
"use_batching": False,
33+
"stopping_criteria": 1e-3,
34+
"use_batching": True,
3335
"optim_algo": "irls",
3436
},
35-
]
36-
EXACT = [
3737
{
3838
"convergence_criteria": "all_converged_ll",
39-
"stopping_criteria": 1e-8,
39+
"stopping_criteria": 1e-6,
4040
"use_batching": False,
4141
"optim_algo": "irls",
4242
},

maths/hessians/nb_glm.pdf

-122 KB
Binary file not shown.

maths/hessians/nb_glm.tex

Lines changed: 0 additions & 253 deletions
This file was deleted.

0 commit comments

Comments (0)