Commit 7bdfebf

Relaunch worker on failure (#2485)
* relaunch worker on failure
* only relaunch in PS strategy
1 parent 4c5e7d8 commit 7bdfebf

3 files changed: +50 -15 lines

elasticdl/python/master/pod_manager.py

Lines changed: 39 additions & 14 deletions
@@ -36,6 +36,7 @@
 from elasticdl_client.common.constants import (
     BashCommandTemplate,
     ClusterSpecConfig,
+    DistributionStrategy,
 )
 from elasticdl_client.common.k8s_client import (
     ELASTICDL_REPLICA_INDEX_KEY,
@@ -45,6 +46,16 @@
 _SERVICE_ADDR_SEP = ","


+class WorkerInfo(object):
+    def __init__(self, pod_priority="", original_index=0, relaunch_count=0):
+        self.pod_priority = pod_priority
+        self.original_index = original_index
+        self.relaunch_count = relaunch_count
+
+    def inc_relaunch_count(self):
+        self.relaunch_count += 1
+
+
 def _get_addrs(num_addrs, addr_get_fn):
     """
     Get `num_addrs` addresses and then concatenate
@@ -159,6 +170,12 @@ def create_pod_manager(args):
     disable_relaunch = kwargs.get("disable_relaunch", False)
     cluster_spec = get_image_cluster_spec(args.cluster_spec)

+    # relaunch on worker failure for PS strategy only
+    if args.distribution_strategy == DistributionStrategy.PARAMETER_SERVER:
+        relaunch_on_worker_failure = args.relaunch_on_worker_failure
+    else:
+        relaunch_on_worker_failure = 0
+
     pod_manager = PodManager(
         job_name=args.job_name,
         image_name=args.worker_image,
@@ -181,6 +198,7 @@ def create_pod_manager(args):
         disable_relaunch=disable_relaunch,
         log_file_path=args.log_file_path,
         need_elasticdl_job_args=args.need_elasticdl_job_service,
+        relaunch_on_worker_failure=relaunch_on_worker_failure,
     )

     return pod_manager
@@ -205,6 +223,7 @@ def __init__(
         disable_relaunch=False,
         log_file_path=None,
         need_elasticdl_job_args=False,
+        relaunch_on_worker_failure=0,
         **kwargs
     ):
         self._num_workers = num_workers
@@ -213,9 +232,11 @@ def __init__(
         worker_pod_priority = _parse_worker_pod_priority(
             self._num_workers, worker_pod_priority
         )
-        self._worker_pod_priority_and_original_index = {}
+        self._worker_info = {}
         for (k, v) in worker_pod_priority.items():
-            self._worker_pod_priority_and_original_index[k] = (v, k)
+            self._worker_info[k] = WorkerInfo(
+                pod_priority=v, original_index=k, relaunch_count=0
+            )

         self._num_ps = num_ps
         self._ps_resource_request = ps_resource_request
@@ -230,6 +251,7 @@ def __init__(
         self._log_file_path = log_file_path
         self._need_tf_config = need_tf_config
         self._need_elasticdl_job_args = need_elasticdl_job_args
+        self._relaunch_on_worker_failure = relaunch_on_worker_failure

         # Protects followed variables, which are accessed from event_cb.
         self._lock = threading.Lock()
@@ -307,9 +329,7 @@ def _start_worker(self, worker_id):
         need_patch_service = False
         original_index = worker_id
         if self._need_tf_config:
-            original_index = self._worker_pod_priority_and_original_index[
-                worker_id
-            ][1]
+            original_index = self._worker_info[worker_id].original_index
             tf_config = self.get_tf_config_data(PodType.WORKER, original_index)
             envs.append(
                 V1EnvVar(name="TF_CONFIG", value=json.dumps(tf_config))
@@ -323,9 +343,7 @@ def _start_worker(self, worker_id):
             worker_id=worker_id,
             resource_requests=self._worker_resource_request,
             resource_limits=self._worker_resource_limit,
-            pod_priority=self._worker_pod_priority_and_original_index[
-                worker_id
-            ][0],
+            pod_priority=self._worker_info[worker_id].pod_priority,
             termination_period=1,
             volume=self._volume,
             image_pull_policy=self._image_pull_policy,
@@ -559,9 +577,18 @@ def _event_cb(self, event):
                 callback.on_pod_failed(pod_info, cluster_context)
                 for callback in self._pod_event_callbacks
             ]
-            should_relaunch = should_relaunch and _should_relaunch_killed_pod(
-                evt_obj=evt_obj
-            )
+            if should_relaunch:
+                should_relaunch = (
+                    should_relaunch
+                    and _should_relaunch_killed_pod(evt_obj=evt_obj)
+                )
+                if (
+                    not should_relaunch
+                    and self._worker_info[pod_id].relaunch_count
+                    < self._relaunch_on_worker_failure
+                ):
+                    self._worker_info[pod_id].inc_relaunch_count()
+                    should_relaunch = True
         elif matched_pod_state_flow.to_status == PodStatus.DELETED:
             [
                 callback.on_pod_deleted(pod_info, cluster_context)
@@ -573,9 +600,7 @@ def _event_cb(self, event):

             new_worker_id = self._next_worker_id_fn()
             with self._lock:
-                self._worker_pod_priority_and_original_index[
-                    new_worker_id
-                ] = self._worker_pod_priority_and_original_index[pod_id]
+                self._worker_info[new_worker_id] = self._worker_info[pod_id]
             self._start_worker(new_worker_id)

     @property
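
Taken together, these hunks change the failure handling in _event_cb: a pod
killed by the scheduler is still always relaunched, while any other worker
failure now consumes a per-worker budget of relaunch_on_worker_failure
retries, tracked in WorkerInfo. The sketch below restates that decision in
isolation; decide_relaunch and killed_by_scheduler are hypothetical
stand-ins for the inline logic in _event_cb and for the
_should_relaunch_killed_pod(evt_obj=evt_obj) check, not names from the
codebase.

# Minimal sketch of the new relaunch decision, with WorkerInfo as in the diff.


class WorkerInfo(object):
    def __init__(self, pod_priority="", original_index=0, relaunch_count=0):
        self.pod_priority = pod_priority
        self.original_index = original_index
        self.relaunch_count = relaunch_count

    def inc_relaunch_count(self):
        self.relaunch_count += 1


def decide_relaunch(worker_info, killed_by_scheduler, max_relaunches):
    # Pods killed by the scheduler (e.g. preempted) keep the old behavior
    # and are always relaunched.
    if killed_by_scheduler:
        return True
    # Other failures draw down the per-worker relaunch budget.
    if worker_info.relaunch_count < max_relaunches:
        worker_info.inc_relaunch_count()
        return True
    return False


info = WorkerInfo(pod_priority="high", original_index=0)
print(decide_relaunch(info, killed_by_scheduler=False, max_relaunches=1))  # True
print(decide_relaunch(info, killed_by_scheduler=False, max_relaunches=1))  # False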

elasticdl/python/tests/pod_manager_test.py

Lines changed: 4 additions & 1 deletion
@@ -33,6 +33,7 @@
 )
 from elasticdl.python.master.pod_manager import (
     PodManager,
+    WorkerInfo,
     _parse_worker_pod_priority,
     build_environment_variables,
 )
@@ -67,7 +68,9 @@ def test_create_delete_worker_pod(self):
                 break

         pod_manager._not_created_worker_id = [2]
-        pod_manager._worker_pod_priority_and_original_index[2] = (None, 1)
+        pod_manager._worker_info[2] = WorkerInfo(
+            pod_priority=None, original_index=1, relaunch_count=0
+        )
         pod_manager._process_worker()
         for _ in range(max_check_num):
             time.sleep(3)
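
A detail worth noting, which the last hunk of pod_manager.py relies on: on
relaunch, the new worker id is mapped to the same WorkerInfo object as the
failed pod, so the original index (used for TF_CONFIG), the pod priority,
and the accumulated relaunch count all carry over. A small illustration,
reusing the WorkerInfo import shown in the test above; the dict here is a
stand-in for PodManager._worker_info:

from elasticdl.python.master.pod_manager import WorkerInfo

worker_info = {2: WorkerInfo(pod_priority=None, original_index=1)}

# Relaunch of worker 2 as worker 3: both ids point at the *same* object,
# so the relaunch budget is shared across all incarnations of the worker.
worker_info[3] = worker_info[2]
worker_info[3].inc_relaunch_count()
assert worker_info[2].relaunch_count == 1
assert worker_info[3].original_index == 1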

elasticdl_client/common/args.py

Lines changed: 7 additions & 0 deletions
@@ -196,6 +196,13 @@ def add_train_params(parser):
         help="If true, needs to set TF_CONFIG env for ps/worker. Also "
         "need to use fixed service name for workers",
     )
+    parser.add_argument(
+        "--relaunch_on_worker_failure",
+        type=int,
+        help="The number of relaunch tries for a worker failure for "
+        "PS Strategy training",
+        default=1,
+    )


 def add_evaluate_params(parser):
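
The flag defaults to 1, so under the PS strategy a worker that fails for a
reason other than a scheduler kill gets one retry unless configured
otherwise, and create_pod_manager forces the value to 0 for every other
distribution strategy. A quick standalone check of how the argument parses,
using argparse directly rather than the full ElasticDL client:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--relaunch_on_worker_failure",
    type=int,
    help="The number of relaunch tries for a worker failure for "
    "PS Strategy training",
    default=1,
)

print(parser.parse_args([]).relaunch_on_worker_failure)  # 1 (default)

args = parser.parse_args(["--relaunch_on_worker_failure", "2"])
print(args.relaunch_on_worker_failure)  # 2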
