Skip to content

Commit a7727c7

Browse files
aivanou authored and facebook-github-bot committed
Change lightning-cv example to make it work for gpu training (#293)
Summary: Pull Request resolved: #293 Change `lightning-cv` example to make it work for gpu training It seems there are bugs with `ddp2` running on GPUs, so changing accelerator to `ddp` Reviewed By: d4l3k Differential Revision: D31818113 fbshipit-source-id: 6bbe4f0c3d62d90cb8d78ca4d912378c659ba24f
1 parent 851745e commit a7727c7

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ kfp==1.6.2
44
pyre-extensions>=0.0.21
55
black>=21.5b1
66
usort==0.6.4
7-
pytorch-lightning>=0.5.3
7+
pytorch-lightning>=1.4.9
88
torch>=1.9.0
99
torchvision>=0.10.0
1010
classy-vision>=0.6.0

scripts/kube_dist_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def register_gpu_resource() -> None:
3333
res = Resource(
3434
cpu=2,
3535
gpu=1,
36-
memMB=4 * GiB,
36+
memMB=8 * GiB,
3737
)
3838
print(f"Registering resource: {res}")
3939
named_resources["GPU_X1"] = res

torchx/examples/apps/lightning_classy_vision/train.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,9 @@ def get_gpu_devices() -> int:
9191
def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
9292
if not args.output_path:
9393
return None
94+
# Note: It is important that each rank behaves the same.
95+
# All of the ranks, or none of them should return ModelCheckpoint
96+
# Otherwise, there will be deadlock for distributed training
9497
return ModelCheckpoint(
9598
monitor="train_loss",
9699
dirpath=args.output_path,
@@ -132,12 +135,14 @@ def main(argv: List[str]) -> None:
132135
logger = TensorBoardLogger(
133136
save_dir=args.log_path, version=1, name="lightning_logs"
134137
)
135-
136138
# Initialize a trainer
137139
num_nodes = int(os.environ.get("GROUP_WORLD_SIZE", 1))
140+
num_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
138141
trainer = pl.Trainer(
139142
num_nodes=num_nodes,
140-
accelerator="ddp2",
143+
num_processes=num_processes,
144+
gpus=get_gpu_devices(),
145+
accelerator="ddp",
141146
logger=logger,
142147
max_epochs=args.epochs,
143148
callbacks=callbacks,
@@ -150,7 +155,8 @@ def main(argv: List[str]) -> None:
150155
f"train acc: {model.train_acc.compute()}, val acc: {model.val_acc.compute()}"
151156
)
152157

153-
if not args.skip_export and args.output_path:
158+
rank = int(os.environ.get("RANK", 0))
159+
if rank == 0 and not args.skip_export and args.output_path:
154160
# Export the inference model
155161
export_inference_model(model, args.output_path, tmpdir)
156162

0 commit comments

Comments (0)