Skip to content

Commit 546dd54

Browse files
authored
Warp the job command using parentheses. (#2495)
* Warp the job command using parentheses * Wrap the bash command * Fix to redirect logs * Wrap the job command * Wrap the job command
1 parent 915acff commit 546dd54

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

elasticdl/python/master/pod_manager.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,10 @@ def _process_worker(self):
321321

322322
def _start_worker(self, worker_id):
323323
logger.info("Starting worker: %d" % worker_id)
324-
job_command = self._complement_job_command()
324+
if self._ps_addrs and self._need_elasticdl_job_args:
325+
# worker_args[1] is the execution command of the worker
326+
self._worker_args[1] += " --ps_addrs {}".format(self._ps_addrs)
327+
job_command = self._complement_job_command(self._worker_args)
325328
worker_args = [self._worker_args[0], job_command]
326329
envs = copy.deepcopy(self._envs)
327330
envs.append(V1EnvVar(name=WorkerEnv.WORKER_ID, value=str(worker_id)))
@@ -366,31 +369,25 @@ def _start_worker(self, worker_id):
366369

367370
return True
368371

369-
def _complement_job_command(self):
370-
# self._worker_args has 2 strings. The first string is "-c" and
372+
def _complement_job_command(self, pod_args):
373+
# pod_args has 2 strings. The first string is "-c" and
371374
# the second string is the shell command to run, like
372-
# ["-c", "python -m elasticdl.python.worker.main --minibatch_size 64"]
373-
job_command = self._worker_args[1]
374-
if self._ps_addrs and self._need_elasticdl_job_args:
375-
job_command += " --ps_addrs {}".format(self._ps_addrs)
375+
# ["-c", "python -m main --minibatch_size 64"]
376+
job_command = pod_args[1]
376377
if self._log_file_path:
377-
job_command += BashCommandTemplate.REDIRECTION.format(
378-
self._log_file_path
378+
job_command = BashCommandTemplate.REDIRECTION.format(
379+
job_command, self._log_file_path
379380
)
380-
job_command += " ".join(self._worker_args[2:])
381+
job_command += " ".join(pod_args[2:])
381382
job_command = BashCommandTemplate.SET_PIPEFAIL + job_command
382383
return job_command
383384

384385
def _start_ps(self, ps_id):
385386
logger.info("Starting PS: %d" % ps_id)
386-
bash_command = self._ps_args[1]
387387
if self._need_elasticdl_job_args:
388-
bash_command += " --ps_id {}".format(ps_id)
389-
if self._log_file_path:
390-
bash_command += BashCommandTemplate.REDIRECTION.format(
391-
self._log_file_path
392-
)
393-
ps_args = [self._ps_args[0], bash_command]
388+
self._ps_args[1] += " --ps_id {}".format(ps_id)
389+
job_command = self._complement_job_command(self._ps_args)
390+
ps_args = [self._ps_args[0], job_command]
394391
while True:
395392
with self._lock:
396393
pod = self._create_ps_pod(ps_id, ps_args)

elasticdl_client/common/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class DistributionStrategy(object):
1919

2020

2121
class BashCommandTemplate(object):
22-
REDIRECTION = " 2>&1 | tee {}"
22+
REDIRECTION = "({}) 2>&1 | tee {}"
2323
SET_PIPEFAIL = "set -o pipefail;"
2424

2525

0 commit comments

Comments
 (0)