@@ -50,7 +50,10 @@ def launch(self,
                *args: Any,
                trainer: Optional["pl.Trainer"] = None,
                **kwargs: Any) -> Any:
-        """Launches the function on the workers from the driver node."""
+        """Launches the function on the workers from the driver node.
+
+        This function is run on the driver process.
+        """
         self.setup_workers()
         ray_output = self.run_function_on_workers(
             function, *args, trainer=trainer, **kwargs)
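For orientation, the driver-side flow that `launch` wraps can be sketched as below. This is a minimal stand-in, assuming a hypothetical `Worker` actor and a fixed worker count; it is not the patched class itself.

```python
import ray

@ray.remote
class Worker:
    def execute(self, fn, *args, **kwargs):
        # This method body runs on the worker process.
        return fn(*args, **kwargs)

def launch(function, *args, num_workers=2, **kwargs):
    # setup_workers(): create the Ray actors from the driver.
    workers = [Worker.remote() for _ in range(num_workers)]
    # run_function_on_workers(): fan the function out, block on all results.
    results = ray.get(
        [w.execute.remote(function, *args, **kwargs) for w in workers])
    # teardown_workers(): release the actors and their resources.
    for w in workers:
        ray.kill(w)
    # Surface only rank 0's return value to the caller.
    return results[0]

ray.init()
assert launch(lambda x: x * 2, 21) == 42
```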
@@ -66,8 +69,9 @@ def launch(self,
         return return_value

     def setup_workers(self, tune_enabled: bool = True) -> None:
-        """Creates the Ray actors and sets up PTL Trainer environment
-        on the worker nodes.
+        """Creates the Ray actors and sets up PTL Trainer environment.
+
+        This function is run on the driver process.
         """
         self._workers = [
             self._create_worker() for _ in range(self._strategy.num_workers)
@@ -99,15 +103,21 @@ def setup_workers(self, tune_enabled: bool = True) -> None:
             self.tune_queue = Queue(actor_options={"num_cpus": 0})

     def _create_worker(self) -> ray.actor.ActorHandle:
-        """Creates Ray actor workers."""
+        """Creates Ray actor workers.
+
+        This function is run on the driver process.
+        """
         worker = RayExecutor.options(
             num_cpus=self._strategy.num_cpus_per_worker,
             num_gpus=self._strategy.num_gpus_per_worker,
             resources=self._strategy.additional_resources_per_worker).remote()
         return worker

     def teardown_workers(self):
-        """Tears down the Ray actors and PTL Trainer environment"""
+        """Tears down the Ray actors and PTL Trainer environment.
+
+        This function is run on the driver process.
+        """
         if self.tune_queue:
             # Shutdown the queue.
             self.tune_queue.shutdown()
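A hedged sketch of the actor lifecycle managed by `_create_worker` and `teardown_workers`: `options(...)` reserves resources per actor, and killing the actor returns them to the cluster. The `Executor` class and the resource numbers are illustrative assumptions, not the library's `RayExecutor`.

```python
import os
import ray

@ray.remote
class Executor:
    # Hypothetical stand-in for the RayExecutor actor class.
    def set_env_var(self, key: str, value: str) -> None:
        os.environ[key] = value

    def get_node_ip(self) -> str:
        return ray.util.get_node_ip_address()

ray.init()
# Mirrors _create_worker: reserve CPUs/GPUs (and any custom resources)
# for this actor up front.
worker = Executor.options(num_cpus=1, num_gpus=0).remote()
print(ray.get(worker.get_node_ip.remote()))
# Mirrors teardown_workers: kill the actor so its resources are freed.
ray.kill(worker)
ray.shutdown()
```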
@@ -119,7 +129,8 @@ def teardown_workers(self):

     def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
         """Creates a mapping of global ranks to local ranks/node ranks.
-        this method is to run on the worker nodes.
+
+        This function is run on the driver process.
         """
         # Get the local ranks for all the workers and store as a list.
         # First get the IP address of each remote worker.
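The mapping this method builds can be approximated in plain Python as follows; the node IPs are made-up inputs, and this free function only illustrates the grouping logic, not the actual remote calls.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def get_local_ranks(worker_ips: List[str]) -> List[Tuple[int, int]]:
    # Index in `worker_ips` is the worker's global rank.
    node_ranks: Dict[str, int] = {}
    local_counts: Dict[str, int] = defaultdict(int)
    global_to_local = []
    for ip in worker_ips:
        if ip not in node_ranks:
            node_ranks[ip] = len(node_ranks)  # first sighting defines node rank
        global_to_local.append((local_counts[ip], node_ranks[ip]))
        local_counts[ip] += 1
    return global_to_local

# Two workers on node 10.0.0.1, one worker on 10.0.0.2:
assert get_local_ranks(["10.0.0.1", "10.0.0.1", "10.0.0.2"]) == [
    (0, 0), (1, 0), (0, 1)]
```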
@@ -146,7 +157,10 @@ def get_local_ranks(self) -> List[Optional[Tuple[int, int]]]:
         return global_to_local

     def _setup_env_vars(self):
-        """Sets environment variables for all workers."""
+        """Sets environment variables for all workers.
+
+        This function is run on the driver process.
+        """
         # Get rank 0 worker address and port for DDP connection.
         os.environ["MASTER_ADDR"] = self._master_addr
         os.environ["MASTER_PORT"] = self._master_port
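A short sketch of the propagation step, assuming a worker actor that exposes a `set_env_var` method; the address, port, and world size are example values.

```python
import os
import ray

@ray.remote
class Worker:
    def set_env_var(self, key: str, value: str) -> None:
        os.environ[key] = value

ray.init()
workers = [Worker.remote() for _ in range(3)]
env = {"MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500", "WORLD_SIZE": "3"}
os.environ.update(env)  # the driver keeps the same rendezvous settings
# Fire all assignments, then block until every worker has applied them.
ray.get([w.set_env_var.remote(k, v) for w in workers for k, v in env.items()])
```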
@@ -162,6 +176,9 @@ def _setup_env_vars(self):

     def _share_cuda_visible_devices(self):
         """Sets CUDA_VISIBLE_DEVICES on all workers.
+
+        This function is run on the driver process.
+
         For each worker, CUDA_VISIBLE_DEVICES will be set to the GPU IDs
         visible to all workers on that worker's node.
         This allows GPU workers on the same node to communicate with one
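The sharing rule the docstring describes (each worker ends up seeing every GPU id in use on its node) can be sketched like this; the node/GPU inputs are invented for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

def share_cuda_visible_devices(
        node_and_gpus: List[Tuple[str, List[int]]]) -> List[str]:
    # Union the GPU ids used by all workers on each node.
    gpus_per_node: Dict[str, Set[int]] = defaultdict(set)
    for ip, gpu_ids in node_and_gpus:
        gpus_per_node[ip].update(gpu_ids)
    # Each worker's CUDA_VISIBLE_DEVICES lists every GPU on its node.
    return [",".join(str(g) for g in sorted(gpus_per_node[ip]))
            for ip, _ in node_and_gpus]

# Workers 0 and 1 share a node (gpus 0 and 1); worker 2 is alone with gpu 0:
assert share_cuda_visible_devices(
    [("nodeA", [0]), ("nodeA", [1]), ("nodeB", [0])]) == ["0,1", "0,1", "0"]
```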
@@ -207,7 +224,10 @@ def run_function_on_workers(self,
                                 trainer: Optional["pl.Trainer"] = None,
                                 **kwargs: Any):
         """launch a function on all workers.
-        The actual training parts are run inside `_wrapping_function`
+
+        This function is run on the driver process.
+
+        The actual training parts are run inside `_wrapping_function`.
         """
         # put the model as the ray object
         # and remove the model temporarily from the args
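The "put the model as the ray object" comment refers to the standard `ray.put` pattern: store the (potentially large) model once in the object store and hand each worker a reference, rather than re-serializing it per call. A minimal sketch, with a dict standing in for the LightningModule:

```python
import ray

@ray.remote
class Worker:
    def execute(self, fn, model, *args, **kwargs):
        # Ray resolves the object ref: `model` arrives as the stored value.
        return fn(model, *args, **kwargs)

ray.init()
model = {"weights": [0.0, 1.0]}  # stand-in for the LightningModule
model_ref = ray.put(model)       # serialized once into the object store
workers = [Worker.remote() for _ in range(2)]
results = ray.get([
    w.execute.remote(lambda m: len(m["weights"]), model_ref)
    for w in workers])
assert results == [2, 2]
```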
@@ -240,21 +260,29 @@ def _wrapping_function(
             tune_queue: Queue,
     ) -> Any:
         """Wraps the function to run on the workers.
-        `results = function(*args, **kwargs)` is where the
-        actual training parts are run.
+
+        This function is run on the worker process.
+
+        `results = function(*args, **kwargs)` is where the
+        actual training parts are run.
         """
         self._strategy.set_remote(True)
         self._strategy.set_global_to_local(global_to_local)

-        # `function` is a trainer's class method
-        # in the ray remote tasks, its object `trainer` will also
-        # be copied when the function is remoted.
+        # `function` is a trainer's instance method.
+        # In Ray remote tasks, its bound instance `trainer`
+        # will also be copied when the function is remoted.
+        #
         # ALERT: passing the trainer as an argument of `_wrapping_function`
-        # does not fillfullied our purpose. Ray remote tasks will
+        # does not fulfill our purpose. Ray remote tasks will
         # create another copy of trainer so that
         # `function.__self__ != trainer`, in which the side effect only
         # happens to `function.__self__` when running
-        # `function(*args, **kwargs)`
+        # `function(*args, **kwargs)` (see SOLUTION below).
+        #
+        # SOLUTION: we take the trainer directly from `function`
+        # via `function.__self__`, so that we can restore all the
+        # side effects that happened to `function.__self__`.
         trainer = function.__self__
         trainer.model = model_ref
         args = tuple([model_ref] + list(args[1:]))
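The ALERT/SOLUTION comment can be demonstrated locally, without a cluster: serializing a bound method (which is effectively what shipping it to a Ray task does, via pickle/cloudpickle) copies the instance it is bound to, so side effects land on `function.__self__`, not on the driver's original object.

```python
import pickle

class Trainer:
    def __init__(self):
        self.state = "initial"

    def fit(self):
        self.state = "fitted"  # the side effect we want to capture
        return self.state

trainer = Trainer()
# Round-tripping the bound method mimics what a remote task receives.
function = pickle.loads(pickle.dumps(trainer.fit))
function()
assert function.__self__ is not trainer      # the receiver was copied
assert function.__self__.state == "fitted"   # side effect lives on the copy...
assert trainer.state == "initial"            # ...and never reaches the original
```

Hence the worker-side code reads the trainer via `function.__self__` and ships its state back explicitly.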
@@ -284,7 +312,10 @@ def _wrapping_function(

     def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                    results: Any) -> Optional["_RayOutput"]:
-        """Collects the results from the worker node 0."""
+        """Collects the results from the worker node 0.
+
+        This function is run on the worker process.
+        """
         rank_zero_debug("Finalizing the Ray launcher environment.")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path \
@@ -316,7 +347,10 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",
 
     def _recover_results_in_main_process(self, ray_output: "_RayOutput",
                                          trainer: "pl.Trainer") -> None:
-        """Recovers the results in the main process."""
+        """Recovers the results in the main process.
+
+        This function is run on the driver process.
+        """
         # transfer back the best path to the trainer
         if trainer.checkpoint_callback:
             trainer.checkpoint_callback.best_model_path = str(
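To make the collect/recover pair concrete, here is a hypothetical round trip: rank 0 packages its state on the worker, and the driver copies it back onto its local trainer. The payload fields are inferred for illustration and are not the library's actual `_RayOutput` definition.

```python
from typing import Any, NamedTuple, Optional

class RayOutputSketch(NamedTuple):
    # Hypothetical payload; the real _RayOutput may carry more fields.
    best_model_path: Optional[str]
    results: Any

class TrainerSketch:
    def __init__(self) -> None:
        self.best_model_path: Optional[str] = None

def collect_rank_zero_results(trainer, results, global_rank):
    # Worker side: only global rank 0 returns a payload.
    if global_rank != 0:
        return None
    return RayOutputSketch(trainer.best_model_path, results)

def recover_results_in_main_process(ray_output, trainer):
    # Driver side: restore the worker's side effects on the local trainer.
    trainer.best_model_path = ray_output.best_model_path

worker_trainer = TrainerSketch()
worker_trainer.best_model_path = "/tmp/best.ckpt"
payload = collect_rank_zero_results(worker_trainer, {"loss": 0.1}, 0)

driver_trainer = TrainerSketch()
recover_results_in_main_process(payload, driver_trainer)
assert driver_trainer.best_model_path == "/tmp/best.ckpt"
```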