@@ -1,4 +1,4 @@
-from typing import Callable, Dict
+from typing import Callable, Dict, List

 import os
 from collections import defaultdict
@@ -7,7 +7,7 @@
 import torch
 from pytorch_lightning.plugins import DDPSpawnPlugin
 from pytorch_lightning import _logger as log, LightningModule
-from ray.util.sgd.torch.utils import setup_address
+from ray.util.sgd.utils import find_free_port

 from ray_lightning.session import init_session
 from ray_lightning.util import process_results, Queue
@@ -20,7 +20,15 @@ class RayExecutor:

     def set_env_var(self, key: str, value: str):
         """Set an environment variable with the provided values."""
-        os.environ[key] = value
+        if value is not None:
+            value = str(value)
+            os.environ[key] = value
+
+    def set_env_vars(self, keys: List[str], values: List[str]):
+        """Sets multiple env vars with the provided values."""
+        assert len(keys) == len(values)
+        for key, value in zip(keys, values):
+            self.set_env_var(key, value)

     def get_node_ip(self):
         """Returns the IP address of the node that this Ray actor is on."""
@@ -137,16 +145,19 @@ def start_training(self, trainer):
         receive intermediate results, and process those results. Finally
         retrieve the training results from the rank 0 worker and return."""

-        if "PL_GLOBAL_SEED" in os.environ:
-            seed = os.environ["PL_GLOBAL_SEED"]
-            ray.get([
-                w.set_env_var.remote("PL_GLOBAL_SEED", seed)
-                for w in self.workers
-            ])
+        # Get rank 0 worker address and port for DDP connection.
+        os.environ["MASTER_ADDR"] = ray.get(
+            self.workers[0].get_node_ip.remote())
+        os.environ["MASTER_PORT"] = str(
+            ray.get(self.workers[0].execute.remote(find_free_port)))

-        # Get the rank 0 address for DDP connection.
-        self.ddp_address = ray.get(
-            self.workers[0].execute.remote(setup_address))
+        # Set environment variables for remote workers.
+        keys = [
+            "PL_GLOBAL_SEED", "PL_TORCH_DISTRIBUTED_BACKEND", "MASTER_ADDR",
+            "MASTER_PORT"
+        ]
+        values = [os.getenv(k) for k in keys]
+        ray.get([w.set_env_vars.remote(keys, values) for w in self.workers])

         self.global_to_local = self.get_local_ranks()

@@ -235,14 +246,15 @@ def init_ddp_connection(self,
                             world_size: int,
                             is_slurm_managing_tasks: bool = False) -> None:
         """Process group creation to be executed on each remote worker."""
-        torch_backend = "nccl" if self.use_gpu else "gloo"
+        torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
+        if torch_backend is None:
+            torch_backend = "nccl" if self.use_gpu else "gloo"

         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER:"
                      f" {global_rank + 1}/{world_size}")
             torch.distributed.init_process_group(
                 backend=torch_backend,
-                init_method=self.ddp_address,
                 rank=global_rank,
                 world_size=world_size,
             )
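For context, the sketch below (illustrative only, not part of this diff; the `_example_worker` helper is hypothetical) shows the rendezvous pattern the change relies on: when `init_process_group` is called without an explicit `init_method`, PyTorch falls back to the `env://` rendezvous and reads `MASTER_ADDR`/`MASTER_PORT` from the environment, which is why propagating those variables to every worker replaces the old `setup_address`-based `init_method`, and why `PL_TORCH_DISTRIBUTED_BACKEND` lets a user override the backend.

```python
import os

import torch
import torch.distributed as dist


def _example_worker(global_rank: int, world_size: int) -> None:
    """Hypothetical worker-side setup, for illustration only."""
    # Assumes MASTER_ADDR and MASTER_PORT were already propagated into this
    # process's environment (e.g. by the set_env_vars() calls above).
    backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
    if backend is None:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
    # With no init_method argument, torch.distributed falls back to the
    # env:// rendezvous and reads MASTER_ADDR / MASTER_PORT from the
    # environment, so no explicit address string is needed.
    dist.init_process_group(
        backend=backend,
        rank=global_rank,
        world_size=world_size,
    )
```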