File tree Expand file tree Collapse file tree 1 file changed +13
-13
lines changed
torchx/examples/apps/lightning_classy_vision Expand file tree Collapse file tree 1 file changed +13
-13
lines changed Original file line number Diff line number Diff line change 27
27
import torch
28
28
from pytorch_lightning .callbacks import ModelCheckpoint
29
29
from pytorch_lightning .loggers import TensorBoardLogger
30
-
31
- # ensure data and module are on the path
32
- sys .path .append ("." )
33
-
34
30
from torchx .examples .apps .lightning_classy_vision .data import (
35
31
TinyImageNetDataModule ,
36
- download_data ,
37
32
create_random_data ,
33
+ download_data ,
38
34
)
39
35
from torchx .examples .apps .lightning_classy_vision .model import (
40
36
TinyImageNetModel ,
41
37
export_inference_model ,
42
38
)
43
- from torchx .examples .apps .lightning_classy_vision .profiler import (
44
- SimpleLoggingProfiler ,
45
- )
39
+ from torchx .examples .apps .lightning_classy_vision .profiler import SimpleLoggingProfiler
40
+
41
+
42
+ # ensure data and module are on the path
43
+ sys .path .append ("." )
46
44
47
45
48
46
def parse_args (argv : List [str ]) -> argparse .Namespace :
@@ -84,10 +82,6 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
84
82
return parser .parse_args (argv )
85
83
86
84
87
- def get_gpu_devices () -> int :
88
- return torch .cuda .device_count ()
89
-
90
-
91
85
def get_model_checkpoint (args : argparse .Namespace ) -> Optional [ModelCheckpoint ]:
92
86
if not args .output_path :
93
87
return None
@@ -138,10 +132,16 @@ def main(argv: List[str]) -> None:
138
132
# Initialize a trainer
139
133
num_nodes = int (os .environ .get ("GROUP_WORLD_SIZE" , 1 ))
140
134
num_processes = int (os .environ .get ("LOCAL_WORLD_SIZE" , 1 ))
135
+
136
+ if torch .cuda .is_available ():
137
+ gpus = num_processes
138
+ else :
139
+ gpus = None
140
+
141
141
trainer = pl .Trainer (
142
142
num_nodes = num_nodes ,
143
143
num_processes = num_processes ,
144
- gpus = get_gpu_devices () ,
144
+ gpus = gpus ,
145
145
accelerator = "ddp" ,
146
146
logger = logger ,
147
147
max_epochs = args .epochs ,
You can’t perform that action at this time.
0 commit comments