@@ -321,7 +321,10 @@ def _process_worker(self):
321
321
322
322
def _start_worker (self , worker_id ):
323
323
logger .info ("Starting worker: %d" % worker_id )
324
- job_command = self ._complement_job_command ()
324
+ if self ._ps_addrs and self ._need_elasticdl_job_args :
325
+ # worker_args[1] is the execution command of the worker
326
+ self ._worker_args [1 ] += " --ps_addrs {}" .format (self ._ps_addrs )
327
+ job_command = self ._complement_job_command (self ._worker_args )
325
328
worker_args = [self ._worker_args [0 ], job_command ]
326
329
envs = copy .deepcopy (self ._envs )
327
330
envs .append (V1EnvVar (name = WorkerEnv .WORKER_ID , value = str (worker_id )))
@@ -366,31 +369,25 @@ def _start_worker(self, worker_id):
366
369
367
370
return True
368
371
369
- def _complement_job_command (self ):
370
- # self._worker_args has 2 strings. The first string is "-c" and
372
+ def _complement_job_command (self , pod_args ):
373
+ # pod_args has 2 strings. The first string is "-c" and
371
374
# the second string is the shell command to run, like
372
- # ["-c", "python -m elasticdl.python.worker.main --minibatch_size 64"]
373
- job_command = self ._worker_args [1 ]
374
- if self ._ps_addrs and self ._need_elasticdl_job_args :
375
- job_command += " --ps_addrs {}" .format (self ._ps_addrs )
375
+ # ["-c", "python -m main --minibatch_size 64"]
376
+ job_command = pod_args [1 ]
376
377
if self ._log_file_path :
377
- job_command + = BashCommandTemplate .REDIRECTION .format (
378
- self ._log_file_path
378
+ job_command = BashCommandTemplate .REDIRECTION .format (
379
+ job_command , self ._log_file_path
379
380
)
380
- job_command += " " .join (self . _worker_args [2 :])
381
+ job_command += " " .join (pod_args [2 :])
381
382
job_command = BashCommandTemplate .SET_PIPEFAIL + job_command
382
383
return job_command
383
384
384
385
def _start_ps (self , ps_id ):
385
386
logger .info ("Starting PS: %d" % ps_id )
386
- bash_command = self ._ps_args [1 ]
387
387
if self ._need_elasticdl_job_args :
388
- bash_command += " --ps_id {}" .format (ps_id )
389
- if self ._log_file_path :
390
- bash_command += BashCommandTemplate .REDIRECTION .format (
391
- self ._log_file_path
392
- )
393
- ps_args = [self ._ps_args [0 ], bash_command ]
388
+ self ._ps_args [1 ] += " --ps_id {}" .format (ps_id )
389
+ job_command = self ._complement_job_command (self ._ps_args )
390
+ ps_args = [self ._ps_args [0 ], job_command ]
394
391
while True :
395
392
with self ._lock :
396
393
pod = self ._create_ps_pod (ps_id , ps_args )
0 commit comments