1212
1313from gluonts .core .component import validated
1414
15+
1516class Trainer :
1617 @validated ()
1718 def __init__ (
@@ -59,35 +60,14 @@ def __call__(
5960 for epoch_no in range (self .epochs ):
6061 # mark epoch start time
6162 tic = time .time ()
62- avg_epoch_loss = 0.0
63-
64- if validation_iter is not None :
65- avg_epoch_loss_val = 0.0
66-
67- train_iter_obj = list (zip (range (1 , train_iter .batch_size + 1 ), tqdm (train_iter )))
68- if validation_iter is not None :
69- val_iter_obj = list (zip (range (1 , validation_iter .batch_size + 1 ), tqdm (validation_iter )))
70-
63+ cumm_epoch_loss = 0.0
7164
65+ # training loop
7266 with tqdm (train_iter ) as it :
73- for batch_no , data_entry in train_iter_obj :
74-
67+ for batch_no , data_entry in enumerate ( it , start = 1 ) :
68+ it . update ( 1 )
7569 optimizer .zero_grad ()
7670
77- # Strong assumption that validation_iter and train_iter are same iter size
78- if validation_iter is not None :
79- with torch .no_grad ():
80- val_data_entry = val_iter_obj [batch_no - 1 ][1 ]
81- inputs_val = [v .to (self .device ) for v in val_data_entry .values ()]
82- output_val = net (* inputs_val )
83-
84- if isinstance (output_val , (list , tuple )):
85- loss_val = output_val [0 ]
86- else :
87- loss_val = output_val
88-
89- avg_epoch_loss_val += loss_val .item ()
90-
9171 inputs = [v .to (self .device ) for v in data_entry .values ()]
9272 output = net (* inputs )
9373
@@ -96,24 +76,18 @@ def __call__(
9676 else :
9777 loss = output
9878
99- avg_epoch_loss += loss .item ()
100- if validation_iter is not None :
101- post_fix_dict = ordered_dict = {
102- "avg_epoch_loss" : avg_epoch_loss / batch_no ,
103- "avg_epoch_loss_val" : avg_epoch_loss_val / batch_no ,
104- "epoch" : epoch_no ,
105- }
106- wandb .log ({"loss_val" : loss_val .item ()})
107- else :
108- post_fix_dict = {
109- "avg_epoch_loss" : avg_epoch_loss / batch_no ,
110- "epoch" : epoch_no ,
111- }
112-
79+ cumm_epoch_loss += loss .item ()
80+ avg_epoch_loss = cumm_epoch_loss / batch_no
81+ it .set_postfix (
82+ {
83+ "epoch" : f"{ epoch_no + 1 } /{ self .epochs } " ,
84+ "avg_loss" : avg_epoch_loss ,
85+ },
86+ refresh = False
87+ )
88+
11389 wandb .log ({"loss" : loss .item ()})
11490
115- it .set_postfix (post_fix_dict , refresh = False )
116-
11791 loss .backward ()
11892 if self .clip_gradient is not None :
11993 nn .utils .clip_grad_norm_ (net .parameters (), self .clip_gradient )
@@ -123,8 +97,37 @@ def __call__(
12397
12498 if self .num_batches_per_epoch == batch_no :
12599 break
126-
100+
101+ # validation loop
102+ if validation_iter is not None :
103+ cumm_epoch_loss_val = 0.0
104+
105+ for batch_no , data_entry in enumerate (validation_iter , start = 1 ):
106+ it .update (1 )
107+ inputs = [v .to (self .device ) for v in data_entry .values ()]
108+ with torch .no_grad ():
109+ output = net (* inputs )
110+ if isinstance (output , (list , tuple )):
111+ loss = output [0 ]
112+ else :
113+ loss = output
114+
115+ cumm_epoch_loss_val += loss .item ()
116+ avg_epoch_loss_val = cumm_epoch_loss_val / batch_no
117+ it .set_postfix (
118+ {
119+ "epoch" : f"{ epoch_no + 1 } /{ self .epochs } " ,
120+ "avg_loss" : avg_epoch_loss ,
121+ "avg_val_loss" : avg_epoch_loss_val ,
122+ },
123+ )
124+
125+ if self .num_batches_per_epoch == batch_no :
126+ break
127+
128+ wandb .log ({"avg_val_loss" : avg_epoch_loss_val })
129+ print (it ) # TODO fix this
130+ it .close ()
131+
127132 # mark epoch end time and log time cost of current epoch
128133 toc = time .time ()
129-
130- # writer.close()
0 commit comments