File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -512,13 +512,11 @@ def train(self, log=False):
512
512
else :
513
513
s_loss = torch .zeros_like (c_loss )
514
514
515
- del res
516
-
517
515
(c_loss + s_loss + z_loss ).backward ()
518
516
519
517
running_c_loss += c_loss .item ()
520
518
running_s_loss += s_loss .item ()
521
- if z_loss :
519
+ if res . zero_scores . size ( 0 ) == 0 :
522
520
running_z_loss += z_loss .item ()
523
521
524
522
# log every 100 docs
@@ -527,12 +525,11 @@ def train(self, log=False):
527
525
'train_c_loss' : c_loss .item (),
528
526
'train_s_loss' : s_loss .item (),
529
527
}
530
- if z_loss :
528
+ if res . zero_scores . size ( 0 ) == 0 :
531
529
logged ['train_z_loss' ] = z_loss .item ()
532
530
wandb .log (logged )
533
531
534
-
535
- del c_loss , s_loss , z_loss
532
+ del c_loss , s_loss , z_loss , res
536
533
537
534
for optim in self .optimizers .values ():
538
535
optim .step ()
You can’t perform that action at this time.
0 commit comments