Skip to content

Commit 91f9b54

Browse files
add tqdm
1 parent 99abe54 commit 91f9b54

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch.optim.lr_scheduler import MultiStepLR
77
from torch.nn.utils import clip_grad_norm_
88
from frames_dataset import DatasetRepeater
9+
from tqdm import tqdm
910
import math
1011
import bitsandbytes as bnb
1112
from accelerate import Accelerator
@@ -65,7 +66,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
6566
models=[inpainting_network, dense_motion_network, kp_detector]
6667
) as logger:
6768
for epoch in trange(start_epoch, train_params['num_epochs']):
68-
for x in dataloader:
69+
for x in tqdm(dataloader):
6970
losses_generator, generated = generator_full(x, epoch)
7071
loss_values = [val.mean() for val in losses_generator.values()]
7172
loss = sum(loss_values)

0 commit comments

Comments
 (0)