Skip to content

Commit 5b8dc69

Browse files
author
Kashif Rasul
committed
partially fix tqdm with validation dataset
1 parent 729c4a1 commit 5b8dc69

File tree

1 file changed

+47

-44

lines changed

pts/trainer.py

Lines changed: 47 additions & 44 deletions
Original file line number · Diff line number · Diff line change
@@ -12,6 +12,7 @@
1212

1313
from gluonts.core.component import validated
1414

15+
1516
class Trainer:
1617
@validated()
1718
def __init__(
@@ -59,35 +60,14 @@ def __call__(
5960
for epoch_no in range(self.epochs):
6061
# mark epoch start time
6162
tic = time.time()
62-
avg_epoch_loss = 0.0
63-
64-
if validation_iter is not None:
65-
avg_epoch_loss_val = 0.0
66-
67-
train_iter_obj = list(zip(range(1, train_iter.batch_size+1), tqdm(train_iter)))
68-
if validation_iter is not None:
69-
val_iter_obj = list(zip(range(1, validation_iter.batch_size+1), tqdm(validation_iter)))
70-
63+
cumm_epoch_loss = 0.0
7164

65+
# training loop
7266
with tqdm(train_iter) as it:
73-
for batch_no, data_entry in train_iter_obj:
74-
67+
for batch_no, data_entry in enumerate(it, start=1):
68+
it.update(1)
7569
optimizer.zero_grad()
7670

77-
# Strong assumption that validation_iter and train_iter are same iter size
78-
if validation_iter is not None:
79-
with torch.no_grad():
80-
val_data_entry = val_iter_obj[batch_no-1][1]
81-
inputs_val = [v.to(self.device) for v in val_data_entry.values()]
82-
output_val = net(*inputs_val)
83-
84-
if isinstance(output_val, (list, tuple)):
85-
loss_val = output_val[0]
86-
else:
87-
loss_val = output_val
88-
89-
avg_epoch_loss_val += loss_val.item()
90-
9171
inputs = [v.to(self.device) for v in data_entry.values()]
9272
output = net(*inputs)
9373

@@ -96,24 +76,18 @@ def __call__(
9676
else:
9777
loss = output
9878

99-
avg_epoch_loss += loss.item()
100-
if validation_iter is not None:
101-
post_fix_dict = ordered_dict={
102-
"avg_epoch_loss": avg_epoch_loss / batch_no,
103-
"avg_epoch_loss_val": avg_epoch_loss_val / batch_no,
104-
"epoch": epoch_no,
105-
}
106-
wandb.log({"loss_val": loss_val.item()})
107-
else:
108-
post_fix_dict={
109-
"avg_epoch_loss": avg_epoch_loss / batch_no,
110-
"epoch": epoch_no,
111-
}
112-
79+
cumm_epoch_loss += loss.item()
80+
avg_epoch_loss = cumm_epoch_loss / batch_no
81+
it.set_postfix(
82+
{
83+
"epoch": f"{epoch_no + 1}/{self.epochs}",
84+
"avg_loss": avg_epoch_loss,
85+
},
86+
refresh=False
87+
)
88+
11389
wandb.log({"loss": loss.item()})
11490

115-
it.set_postfix(post_fix_dict, refresh=False)
116-
11791
loss.backward()
11892
if self.clip_gradient is not None:
11993
nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient)
@@ -123,8 +97,37 @@ def __call__(
12397

12498
if self.num_batches_per_epoch == batch_no:
12599
break
126-
100+
101+
# validation loop
102+
if validation_iter is not None:
103+
cumm_epoch_loss_val = 0.0
104+
105+
for batch_no, data_entry in enumerate(validation_iter, start=1):
106+
it.update(1)
107+
inputs = [v.to(self.device) for v in data_entry.values()]
108+
with torch.no_grad():
109+
output = net(*inputs)
110+
if isinstance(output, (list, tuple)):
111+
loss = output[0]
112+
else:
113+
loss = output
114+
115+
cumm_epoch_loss_val += loss.item()
116+
avg_epoch_loss_val = cumm_epoch_loss_val / batch_no
117+
it.set_postfix(
118+
{
119+
"epoch": f"{epoch_no + 1}/{self.epochs}",
120+
"avg_loss": avg_epoch_loss,
121+
"avg_val_loss": avg_epoch_loss_val,
122+
},
123+
)
124+
125+
if self.num_batches_per_epoch == batch_no:
126+
break
127+
128+
wandb.log({"avg_val_loss": avg_epoch_loss_val})
129+
print(it) # TODO fix this
130+
it.close()
131+
127132
# mark epoch end time and log time cost of current epoch
128133
toc = time.time()
129-
130-
# writer.close()

Comments (0)