36
36
from elasticdl_client .common .constants import (
37
37
BashCommandTemplate ,
38
38
ClusterSpecConfig ,
39
+ DistributionStrategy ,
39
40
)
40
41
from elasticdl_client .common .k8s_client import (
41
42
ELASTICDL_REPLICA_INDEX_KEY ,
45
46
_SERVICE_ADDR_SEP = ","
46
47
47
48
49
class WorkerInfo(object):
    """Bookkeeping record for a single worker pod managed by the pod manager.

    Tracks the scheduling priority the worker was launched with, the index
    the worker was originally assigned (used e.g. to rebuild TF_CONFIG when
    a pod is relaunched under a new worker id), and how many times the
    worker has been relaunched after failures.
    """

    def __init__(self, pod_priority="", original_index=0, relaunch_count=0):
        """Initialize the worker record.

        Args:
            pod_priority: Kubernetes pod priority class name for this
                worker; empty string means no explicit priority.
            original_index: The worker index assigned at the first launch;
                preserved across relaunches.
            relaunch_count: Number of relaunches already performed for
                this worker (0 for a freshly created worker).
        """
        self.pod_priority = pod_priority
        self.original_index = original_index
        self.relaunch_count = relaunch_count

    def inc_relaunch_count(self):
        """Record one more relaunch of this worker."""
        self.relaunch_count += 1

    def __repr__(self):
        # Debug-friendly representation; added for logging convenience,
        # does not affect any existing caller.
        return (
            "WorkerInfo(pod_priority=%r, original_index=%r, "
            "relaunch_count=%r)"
            % (self.pod_priority, self.original_index, self.relaunch_count)
        )
48
59
def _get_addrs (num_addrs , addr_get_fn ):
49
60
"""
50
61
Get `num_addrs` addresses and then concatenate
@@ -159,6 +170,12 @@ def create_pod_manager(args):
159
170
disable_relaunch = kwargs .get ("disable_relaunch" , False )
160
171
cluster_spec = get_image_cluster_spec (args .cluster_spec )
161
172
173
+ # relaunch on worker failure for PS strategy only
174
+ if args .distribution_strategy == DistributionStrategy .PARAMETER_SERVER :
175
+ relaunch_on_worker_failure = args .relaunch_on_worker_failure
176
+ else :
177
+ relaunch_on_worker_failure = 0
178
+
162
179
pod_manager = PodManager (
163
180
job_name = args .job_name ,
164
181
image_name = args .worker_image ,
@@ -181,6 +198,7 @@ def create_pod_manager(args):
181
198
disable_relaunch = disable_relaunch ,
182
199
log_file_path = args .log_file_path ,
183
200
need_elasticdl_job_args = args .need_elasticdl_job_service ,
201
+ relaunch_on_worker_failure = relaunch_on_worker_failure ,
184
202
)
185
203
186
204
return pod_manager
@@ -205,6 +223,7 @@ def __init__(
205
223
disable_relaunch = False ,
206
224
log_file_path = None ,
207
225
need_elasticdl_job_args = False ,
226
+ relaunch_on_worker_failure = 0 ,
208
227
** kwargs
209
228
):
210
229
self ._num_workers = num_workers
@@ -213,9 +232,11 @@ def __init__(
213
232
worker_pod_priority = _parse_worker_pod_priority (
214
233
self ._num_workers , worker_pod_priority
215
234
)
216
- self ._worker_pod_priority_and_original_index = {}
235
+ self ._worker_info = {}
217
236
for (k , v ) in worker_pod_priority .items ():
218
- self ._worker_pod_priority_and_original_index [k ] = (v , k )
237
+ self ._worker_info [k ] = WorkerInfo (
238
+ pod_priority = v , original_index = k , relaunch_count = 0
239
+ )
219
240
220
241
self ._num_ps = num_ps
221
242
self ._ps_resource_request = ps_resource_request
@@ -230,6 +251,7 @@ def __init__(
230
251
self ._log_file_path = log_file_path
231
252
self ._need_tf_config = need_tf_config
232
253
self ._need_elasticdl_job_args = need_elasticdl_job_args
254
+ self ._relaunch_on_worker_failure = relaunch_on_worker_failure
233
255
234
256
# Protects followed variables, which are accessed from event_cb.
235
257
self ._lock = threading .Lock ()
@@ -307,9 +329,7 @@ def _start_worker(self, worker_id):
307
329
need_patch_service = False
308
330
original_index = worker_id
309
331
if self ._need_tf_config :
310
- original_index = self ._worker_pod_priority_and_original_index [
311
- worker_id
312
- ][1 ]
332
+ original_index = self ._worker_info [worker_id ].original_index
313
333
tf_config = self .get_tf_config_data (PodType .WORKER , original_index )
314
334
envs .append (
315
335
V1EnvVar (name = "TF_CONFIG" , value = json .dumps (tf_config ))
@@ -323,9 +343,7 @@ def _start_worker(self, worker_id):
323
343
worker_id = worker_id ,
324
344
resource_requests = self ._worker_resource_request ,
325
345
resource_limits = self ._worker_resource_limit ,
326
- pod_priority = self ._worker_pod_priority_and_original_index [
327
- worker_id
328
- ][0 ],
346
+ pod_priority = self ._worker_info [worker_id ].pod_priority ,
329
347
termination_period = 1 ,
330
348
volume = self ._volume ,
331
349
image_pull_policy = self ._image_pull_policy ,
@@ -559,9 +577,18 @@ def _event_cb(self, event):
559
577
callback .on_pod_failed (pod_info , cluster_context )
560
578
for callback in self ._pod_event_callbacks
561
579
]
562
- should_relaunch = should_relaunch and _should_relaunch_killed_pod (
563
- evt_obj = evt_obj
564
- )
580
+ if should_relaunch :
581
+ should_relaunch = (
582
+ should_relaunch
583
+ and _should_relaunch_killed_pod (evt_obj = evt_obj )
584
+ )
585
+ if (
586
+ not should_relaunch
587
+ and self ._worker_info [pod_id ].relaunch_count
588
+ < self ._relaunch_on_worker_failure
589
+ ):
590
+ self ._worker_info [pod_id ].inc_relaunch_count ()
591
+ should_relaunch = True
565
592
elif matched_pod_state_flow .to_status == PodStatus .DELETED :
566
593
[
567
594
callback .on_pod_deleted (pod_info , cluster_context )
@@ -573,9 +600,7 @@ def _event_cb(self, event):
573
600
574
601
new_worker_id = self ._next_worker_id_fn ()
575
602
with self ._lock :
576
- self ._worker_pod_priority_and_original_index [
577
- new_worker_id
578
- ] = self ._worker_pod_priority_and_original_index [pod_id ]
603
+ self ._worker_info [new_worker_id ] = self ._worker_info [pod_id ]
579
604
self ._start_worker (new_worker_id )
580
605
581
606
@property
0 commit comments