11import time
22from typing import List , Optional , Union
33
4- from tqdm import tqdm
4+ from tqdm . auto import tqdm
55import wandb
66
77import torch
@@ -61,11 +61,11 @@ def __call__(
6161 # mark epoch start time
6262 tic = time .time ()
6363 cumm_epoch_loss = 0.0
64+ total = self .num_batches_per_epoch - 1
6465
6566 # training loop
66- with tqdm (train_iter ) as it :
67+ with tqdm (train_iter , total = total ) as it :
6768 for batch_no , data_entry in enumerate (it , start = 1 ):
68- it .update (1 )
6969 optimizer .zero_grad ()
7070
7171 inputs = [v .to (self .device ) for v in data_entry .values ()]
@@ -83,7 +83,7 @@ def __call__(
8383 "epoch" : f"{ epoch_no + 1 } /{ self .epochs } " ,
8484 "avg_loss" : avg_epoch_loss ,
8585 },
86- refresh = False
86+ refresh = False ,
8787 )
8888
8989 wandb .log ({"loss" : loss .item ()})
@@ -97,13 +97,14 @@ def __call__(
9797
9898 if self .num_batches_per_epoch == batch_no :
9999 break
100+ it .close ()
100101
101- # validation loop
102- if validation_iter is not None :
103- cumm_epoch_loss_val = 0.0
102+ # validation loop
103+ if validation_iter is not None :
104+ cumm_epoch_loss_val = 0.0
105+ with tqdm (validation_iter , total = total , colour = "green" ) as it :
104106
105- for batch_no , data_entry in enumerate (validation_iter , start = 1 ):
106- it .update (1 )
107+ for batch_no , data_entry in enumerate (it , start = 1 ):
107108 inputs = [v .to (self .device ) for v in data_entry .values ()]
108109 with torch .no_grad ():
109110 output = net (* inputs )
@@ -120,13 +121,13 @@ def __call__(
120121 "avg_loss" : avg_epoch_loss ,
121122 "avg_val_loss" : avg_epoch_loss_val ,
122123 },
124+ refresh = False ,
123125 )
124126
125127 if self .num_batches_per_epoch == batch_no :
126128 break
127129
128130 wandb .log ({"avg_val_loss" : avg_epoch_loss_val })
129- print (it ) # TODO fix this
130131 it .close ()
131132
132133 # mark epoch end time and log time cost of current epoch
0 commit comments