Skip to content

Commit 916f31a

Browse files
author
Alexander Lyashuk
committed
Logging only in the main process when training in the distributed mode
1 parent 88f7b08 commit 916f31a

File tree

2 files changed

+47
-39
lines changed

2 files changed

+47
-39
lines changed

PyTorch-Vision-Experiment-Logging/engine.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,20 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
3232
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
3333
loss_dict = model(images, targets)
3434

35+
# applying logging only in the main process
3536
# ### OUR CODE ###
36-
# let's track the losses here by adding scalars
37-
tensorboard.logger.add_scalar_dict(
38-
# passing the dictionary of losses (pairs - loss_key: loss_value)
39-
loss_dict,
40-
# passing the global step (number of iterations)
41-
global_step=tensorboard.global_iter,
42-
# adding the tag to combine plots in a subgroup
43-
tag="loss"
44-
)
45-
# incrementing the global step (number of iterations)
46-
tensorboard.global_iter += 1
37+
if utils.is_main_process():
38+
# let's track the losses here by adding scalars
39+
tensorboard.logger.add_scalar_dict(
40+
# passing the dictionary of losses (pairs - loss_key: loss_value)
41+
loss_dict,
42+
# passing the global step (number of iterations)
43+
global_step=tensorboard.global_iter,
44+
# adding the tag to combine plots in a subgroup
45+
tag="loss"
46+
)
47+
# incrementing the global step (number of iterations)
48+
tensorboard.global_iter += 1
4749
# ### END OF OUR CODE ###
4850

4951
losses = sum(loss for loss in loss_dict.values())
@@ -109,25 +111,27 @@ def evaluate(model, data_loader, device):
109111
model_time = time.time()
110112
outputs = model(images)
111113

114+
# applying logging only in the main process
112115
# ### OUR CODE ###
113-
# let's track bounding box and labels predictions for the first 50 images
114-
# as we hardly want to track all validation images
115-
# but want to see how the predicted bounding boxes and labels are changing during the process
116-
if i < 50:
117-
# let's add tracking images with predicted bounding boxes
118-
tensorboard.logger.add_image_with_boxes(
119-
# adding pred_images tag to combine images in one subgroup
120-
"pred_images/PD-{}".format(i),
121-
# passing image tensor
122-
img,
123-
# passing predicted bounding boxes
124-
outputs[0]["boxes"].cpu(),
125-
# mapping & passing predicted labels
126-
labels=[
127-
tensorboard.COCO_INSTANCE_CATEGORY_NAMES[i]
128-
for i in outputs[0]["labels"].cpu().numpy()
129-
],
130-
)
116+
if utils.is_main_process():
117+
# let's track bounding box and labels predictions for the first 50 images
118+
# since we don't want to track every validation image
119+
# but want to see how the predicted bounding boxes and labels are changing during the process
120+
if i < 50:
121+
# let's add tracking images with predicted bounding boxes
122+
tensorboard.logger.add_image_with_boxes(
123+
# adding pred_images tag to combine images in one subgroup
124+
"pred_images/PD-{}".format(i),
125+
# passing image tensor
126+
img,
127+
# passing predicted bounding boxes
128+
outputs[0]["boxes"].cpu(),
129+
# mapping & passing predicted labels
130+
labels=[
131+
tensorboard.COCO_INSTANCE_CATEGORY_NAMES[i]
132+
for i in outputs[0]["labels"].cpu().numpy()
133+
],
134+
)
131135
# ### END OF OUR CODE ###
132136
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
133137
model_time = time.time() - model_time
@@ -144,7 +148,9 @@ def evaluate(model, data_loader, device):
144148
coco_evaluator.synchronize_between_processes()
145149

146150
# accumulate predictions from all images
147-
coco_evaluator.accumulate()
148-
coco_evaluator.summarize()
151+
# add main process check for multi-gpu training (torch.distributed)
152+
if utils.is_main_process():
153+
coco_evaluator.accumulate()
154+
coco_evaluator.summarize()
149155
torch.set_num_threads(n_threads)
150156
return coco_evaluator

PyTorch-Vision-Experiment-Logging/train.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ def get_transform(train):
6565
def main(args):
6666
utils.init_distributed_mode(args)
6767
print(args)
68+
# applying logging only in the main process
69+
# ### OUR CODE ###
70+
if utils.is_main_process():
71+
# passing argparse config with hyperparameters
72+
tensorboard.args = vars(args)
73+
# init wandb using config and experiment name
74+
wandb.init(config=vars(args), name=tensorboard.name)
75+
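# enable tensorboard sync (NOTE(review): this is a second wandb.init call right after the one above — consider passing sync_tensorboard=True to that first call instead; confirm against the wandb docs)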
76+
wandb.init(sync_tensorboard=True)
77+
# ### END OF OUR CODE ###
6878

6979
device = torch.device(args.device)
7080

@@ -204,12 +214,4 @@ def main(args):
204214
if args.output_dir:
205215
utils.mkdir(args.output_dir)
206216

207-
# ### OUR CODE ###
208-
# passing argparse config with hyperparameters
209-
tensorboard.args = vars(args)
210-
# init wandb using config and experiment name
211-
wandb.init(config=vars(args), name=tensorboard.name)
212-
# enable tensorboard sync
213-
wandb.init(sync_tensorboard=True)
214-
# ### END OF OUR CODE ###
215217
main(args)

0 commit comments

Comments
 (0)