Skip to content

Commit 8226743

Browse files
Kiuk Chung authored and facebook-github-bot committed
(torchx/examples) set the correct number of gpus based on LOCAL_RANK or None (if cpu) (#296)
Summary: Pull Request resolved: #296. See title. Reviewed By: aivanou. Differential Revision: D31830121. fbshipit-source-id: a30cceaedb4b2002f20d319b7f27c2116e96278f
1 parent ddcae5c commit 8226743

File tree

1 file changed

+13 −13 lines changed
  • torchx/examples/apps/lightning_classy_vision

1 file changed

+13 −13 lines changed

torchx/examples/apps/lightning_classy_vision/train.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -27,22 +27,20 @@
2727
import torch
2828
from pytorch_lightning.callbacks import ModelCheckpoint
2929
from pytorch_lightning.loggers import TensorBoardLogger
30-
31-
# ensure data and module are on the path
32-
sys.path.append(".")
33-
3430
from torchx.examples.apps.lightning_classy_vision.data import (
3531
TinyImageNetDataModule,
36-
download_data,
3732
create_random_data,
33+
download_data,
3834
)
3935
from torchx.examples.apps.lightning_classy_vision.model import (
4036
TinyImageNetModel,
4137
export_inference_model,
4238
)
43-
from torchx.examples.apps.lightning_classy_vision.profiler import (
44-
SimpleLoggingProfiler,
45-
)
39+
from torchx.examples.apps.lightning_classy_vision.profiler import SimpleLoggingProfiler
40+
41+
42+
# ensure data and module are on the path
43+
sys.path.append(".")
4644

4745

4846
def parse_args(argv: List[str]) -> argparse.Namespace:
@@ -84,10 +82,6 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
8482
return parser.parse_args(argv)
8583

8684

87-
def get_gpu_devices() -> int:
88-
return torch.cuda.device_count()
89-
90-
9185
def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
9286
if not args.output_path:
9387
return None
@@ -138,10 +132,16 @@ def main(argv: List[str]) -> None:
138132
# Initialize a trainer
139133
num_nodes = int(os.environ.get("GROUP_WORLD_SIZE", 1))
140134
num_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
135+
136+
if torch.cuda.is_available():
137+
gpus = num_processes
138+
else:
139+
gpus = None
140+
141141
trainer = pl.Trainer(
142142
num_nodes=num_nodes,
143143
num_processes=num_processes,
144-
gpus=get_gpu_devices(),
144+
gpus=gpus,
145145
accelerator="ddp",
146146
logger=logger,
147147
max_epochs=args.epochs,

0 commit comments

Comments
 (0)