Skip to content

Commit cfbfcf2

Browse files
Add elasticdl_ps file path into PATH environment variable before executing the binary in the PS pod. (#2137)
* For the PS pod command, add the path of elasticdl_ps into PATH environment variable and then execute elasticdl_ps. * Fix the incorrect usage of extend method which doesn't have return value.
1 parent 5fae532 commit cfbfcf2

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

elasticdl/python/master/master.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,8 @@ def _create_instance_manager(self, args):
409409

410410
if args.use_go_ps:
411411
opt_type, opt_args = get_optimizer_info(self.optimizer)
412-
ps_client_command = "elasticdl_ps"
413-
ps_args = [
412+
ps_command = "elasticdl_ps"
413+
ps_command_args = [
414414
"-job_name=" + args.job_name,
415415
"-namespace=" + args.namespace,
416416
"-master_addr=" + self.master_addr,
@@ -432,14 +432,22 @@ def _create_instance_manager(self, args):
432432
"-opt_type=" + opt_type,
433433
"-opt_args=" + opt_args,
434434
]
435-
ps_args = wrap_go_args_with_string(ps_args)
436-
ps_args.insert(0, ps_client_command)
435+
ps_command_args = wrap_go_args_with_string(ps_command_args)
436+
# Execute source /root/.bashrc to add the file path
437+
# of `elasticdl_ps` into the PATH environment variable.
438+
ps_args = [
439+
"source",
440+
"/root/.bashrc",
441+
"&&",
442+
ps_command,
443+
]
444+
ps_args.extend(ps_command_args)
437445
else:
438-
ps_client_command = (
446+
ps_command = (
439447
BashCommandTemplate.SET_PIPEFAIL
440448
+ " python -m elasticdl.python.ps.main"
441449
)
442-
ps_args = [
450+
ps_command_args = [
443451
"--grads_to_wait",
444452
str(args.grads_to_wait),
445453
"--lr_staleness_modulation",
@@ -481,8 +489,8 @@ def _create_instance_manager(self, args):
481489
"--num_minibatches_per_task",
482490
str(args.num_minibatches_per_task),
483491
]
484-
ps_args = wrap_python_args_with_string(ps_args)
485-
ps_args.insert(0, ps_client_command)
492+
ps_args = wrap_python_args_with_string(ps_command_args)
493+
ps_args.insert(0, ps_command)
486494

487495
worker_args = ["-c", " ".join(worker_args)]
488496
ps_args = ["-c", " ".join(ps_args)]

0 commit comments

Comments
 (0)