Skip to content

Commit 5da3be5

Browse files
author
Kashif Rasul
committed
fix tqdm
1 parent 5b8dc69 commit 5da3be5

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

pts/trainer.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import time
22
from typing import List, Optional, Union
33

4-
from tqdm import tqdm
4+
from tqdm.auto import tqdm
55
import wandb
66

77
import torch
@@ -61,11 +61,11 @@ def __call__(
6161
# mark epoch start time
6262
tic = time.time()
6363
cumm_epoch_loss = 0.0
64+
total = self.num_batches_per_epoch - 1
6465

6566
# training loop
66-
with tqdm(train_iter) as it:
67+
with tqdm(train_iter, total=total) as it:
6768
for batch_no, data_entry in enumerate(it, start=1):
68-
it.update(1)
6969
optimizer.zero_grad()
7070

7171
inputs = [v.to(self.device) for v in data_entry.values()]
@@ -83,7 +83,7 @@ def __call__(
8383
"epoch": f"{epoch_no + 1}/{self.epochs}",
8484
"avg_loss": avg_epoch_loss,
8585
},
86-
refresh=False
86+
refresh=False,
8787
)
8888

8989
wandb.log({"loss": loss.item()})
@@ -97,13 +97,14 @@ def __call__(
9797

9898
if self.num_batches_per_epoch == batch_no:
9999
break
100+
it.close()
100101

101-
# validation loop
102-
if validation_iter is not None:
103-
cumm_epoch_loss_val = 0.0
102+
# validation loop
103+
if validation_iter is not None:
104+
cumm_epoch_loss_val = 0.0
105+
with tqdm(validation_iter, total=total, colour="green") as it:
104106

105-
for batch_no, data_entry in enumerate(validation_iter, start=1):
106-
it.update(1)
107+
for batch_no, data_entry in enumerate(it, start=1):
107108
inputs = [v.to(self.device) for v in data_entry.values()]
108109
with torch.no_grad():
109110
output = net(*inputs)
@@ -120,13 +121,13 @@ def __call__(
120121
"avg_loss": avg_epoch_loss,
121122
"avg_val_loss": avg_epoch_loss_val,
122123
},
124+
refresh=False,
123125
)
124126

125127
if self.num_batches_per_epoch == batch_no:
126128
break
127129

128130
wandb.log({"avg_val_loss": avg_epoch_loss_val})
129-
print(it) # TODO fix this
130131
it.close()
131132

132133
# mark epoch end time and log time cost of current epoch

0 commit comments

Comments (0)