11
11
# See the License for the specific language governing permissions and
12
12
# limitations under the License.
13
13
14
- import os
15
14
import threading
16
15
import time
17
- import traceback
18
16
19
- import yaml
20
- from kubernetes import client , config , watch
21
- from kubernetes .client import V1EnvVar , V1EnvVarSource , V1ObjectFieldSelector
17
+ from kubernetes import client , watch
22
18
23
- from elasticdl .python .common .k8s_resource import parse as parse_resource
24
- from elasticdl .python .common .k8s_volume import parse_volume_and_mount
25
19
from elasticdl .python .common .log_utils import default_logger as logger
26
- from elasticdl .python .common .model_utils import load_module
20
+ from elasticdl_client .common .k8s_client import (
21
+ ELASTICDL_APP_NAME ,
22
+ ELASTICDL_JOB_KEY ,
23
+ ELASTICDL_REPLICA_INDEX_KEY ,
24
+ ELASTICDL_REPLICA_TYPE_KEY ,
25
+ )
26
+ from elasticdl_client .common .k8s_client import Client as BaseClient
27
+ from elasticdl_client .common .k8s_client import append_pod_ip_to_env
27
28
28
- ELASTICDL_APP_NAME = "elasticdl"
29
- ELASTICDL_JOB_KEY = "elasticdl-job-name"
30
- ELASTICDL_REPLICA_TYPE_KEY = "elasticdl-replica-type"
31
- ELASTICDL_REPLICA_INDEX_KEY = "elasticdl-replica-index"
32
29
_PS_SERVICE_PORT = 2222
33
30
_WORKER_SERVICE_PORT = 3333
34
31
_FTLIB_GOSSIP_CONTAINER_PORT = 7946
35
- _FTLIB_SSH_CONTAINER_PORT = 22
36
-
37
-
38
- def get_master_pod_name (job_name ):
39
- return "elasticdl-%s-master" % job_name
40
32
41
33
42
34
def get_worker_pod_name (job_name , worker_id ):
@@ -47,21 +39,7 @@ def get_ps_pod_name(job_name, ps_id):
47
39
return "elasticdl-%s-ps-%s" % (job_name , str (ps_id ))
48
40
49
41
50
- def append_pod_ip_to_env (env ):
51
- pod_ip_var = V1EnvVar (
52
- name = "MY_POD_IP" ,
53
- value_from = V1EnvVarSource (
54
- field_ref = V1ObjectFieldSelector (field_path = "status.podIP" )
55
- ),
56
- )
57
- if env :
58
- env .append (pod_ip_var )
59
- else :
60
- env = [pod_ip_var ]
61
- return env
62
-
63
-
64
- class Client (object ):
42
+ class Client (BaseClient ):
65
43
def __init__ (
66
44
self ,
67
45
* ,
@@ -88,37 +66,18 @@ def __init__(
88
66
running in a K8S environment, it loads the incluster config,
89
67
if not, it loads the kube config file.
90
68
"""
91
- try :
92
- if (
93
- os .getenv ("KUBERNETES_SERVICE_HOST" )
94
- and not force_use_kube_config_file
95
- ):
96
- # We are running inside a k8s cluster
97
- config .load_incluster_config ()
98
- logger .info ("Load the incluster config." )
99
- else :
100
- # Use user's kube config
101
- config .load_kube_config ()
102
- logger .info ("Load the kube config file." )
103
- except Exception as ex :
104
- traceback .print_exc ()
105
- raise Exception (
106
- "Failed to load configuration for Kubernetes:\n %s" % str (ex )
107
- )
108
-
109
- self .client = client .CoreV1Api ()
110
- self .namespace = namespace
111
- self .job_name = job_name
112
- self ._image_name = image_name
69
+ super ().__init__ (
70
+ image_name = image_name ,
71
+ namespace = namespace ,
72
+ job_name = job_name ,
73
+ cluster_spec = cluster_spec ,
74
+ force_use_kube_config_file = force_use_kube_config_file ,
75
+ )
113
76
self ._event_cb = event_callback
114
77
if self ._event_cb :
115
78
threading .Thread (
116
79
target = self ._watch , name = "event_watcher" , daemon = True
117
80
).start ()
118
- self .cluster = None
119
- if cluster_spec :
120
- cluster_spec_module = load_module (cluster_spec )
121
- self .cluster = cluster_spec_module .cluster
122
81
123
82
def _watch (self ):
124
83
while True :
@@ -139,9 +98,6 @@ def _watch(self):
139
98
def _get_service_address (self , service_name , port ):
140
99
return "%s.%s.svc:%d" % (service_name , self .namespace , port )
141
100
142
- def get_master_pod_name (self ):
143
- return get_master_pod_name (self .job_name )
144
-
145
101
def get_worker_pod_name (self , worker_id ):
146
102
return get_worker_pod_name (self .job_name , worker_id )
147
103
@@ -164,16 +120,6 @@ def get_ps_service_address(self, ps_id):
164
120
self .get_ps_service_name (ps_id ), _PS_SERVICE_PORT
165
121
)
166
122
167
- def patch_labels_to_pod (self , pod_name , labels_dict ):
168
- body = {"metadata" : {"labels" : labels_dict }}
169
- try :
170
- return self .client .patch_namespaced_pod (
171
- name = pod_name , namespace = self .namespace , body = body
172
- )
173
- except client .api_client .ApiException as e :
174
- logger .warning ("Exception when patching labels to pod: %s\n " % e )
175
- return None
176
-
177
123
def get_master_pod (self ):
178
124
return self .get_pod (self .get_master_pod_name ())
179
125
@@ -216,129 +162,6 @@ def get_worker_service(self, worker_id):
216
162
logger .warning ("Exception when reading worker service: %s\n " % e )
217
163
return None
218
164
219
- @staticmethod
220
- def create_owner_reference (owner_pod ):
221
- owner_ref = (
222
- [
223
- client .V1OwnerReference (
224
- api_version = "v1" ,
225
- block_owner_deletion = True ,
226
- kind = "Pod" ,
227
- name = owner_pod .metadata .name ,
228
- uid = owner_pod .metadata .uid ,
229
- )
230
- ]
231
- if owner_pod
232
- else None
233
- )
234
- return owner_ref
235
-
236
- def create_pod (self , ** kargs ):
237
- # Container
238
- pod_resource_requests = kargs ["resource_requests" ]
239
- pod_resource_limits = kargs ["resource_limits" ]
240
- pod_resource_limits = (
241
- pod_resource_limits
242
- if pod_resource_limits
243
- else pod_resource_requests
244
- )
245
- ports = (
246
- [
247
- client .V1ContainerPort (
248
- container_port = _FTLIB_GOSSIP_CONTAINER_PORT , name = "gossip"
249
- ),
250
- ]
251
- if "expose_ports" in kargs and kargs ["expose_ports" ]
252
- else None
253
- )
254
- container = client .V1Container (
255
- name = kargs ["pod_name" ],
256
- image = kargs ["image_name" ],
257
- command = kargs ["command" ],
258
- resources = client .V1ResourceRequirements (
259
- requests = parse_resource (pod_resource_requests ),
260
- limits = parse_resource (pod_resource_limits ),
261
- ),
262
- args = kargs ["container_args" ],
263
- image_pull_policy = kargs ["image_pull_policy" ],
264
- env = kargs ["env" ],
265
- ports = ports ,
266
- )
267
-
268
- # Pod
269
- spec = client .V1PodSpec (
270
- containers = [container ],
271
- restart_policy = kargs ["restart_policy" ],
272
- priority_class_name = kargs ["pod_priority" ],
273
- termination_grace_period_seconds = kargs .get (
274
- "termination_period" , None
275
- ),
276
- )
277
-
278
- # Mount data path
279
- if kargs ["volume" ]:
280
- volumes , volume_mounts = parse_volume_and_mount (
281
- kargs ["volume" ], kargs ["pod_name" ]
282
- )
283
- spec .volumes = volumes
284
- container .volume_mounts = volume_mounts
285
-
286
- pod = client .V1Pod (
287
- spec = spec ,
288
- metadata = client .V1ObjectMeta (
289
- name = kargs ["pod_name" ],
290
- labels = self ._get_common_labels (),
291
- owner_references = self .create_owner_reference (
292
- kargs ["owner_pod" ]
293
- ),
294
- namespace = self .namespace ,
295
- ),
296
- )
297
- if self .cluster :
298
- pod = self .cluster .with_pod (pod )
299
-
300
- return pod
301
-
302
- def create_master (self , ** kargs ):
303
- pod = self ._create_master_pod_obj (** kargs )
304
- self .client .create_namespaced_pod (self .namespace , pod )
305
- logger .info ("Master launched." )
306
-
307
- def dump_master_yaml (self , ** kargs ):
308
- pod = self ._create_master_pod_obj (** kargs )
309
- pod_dict = self .client .api_client .sanitize_for_serialization (pod )
310
- with open (kargs ["yaml" ], "w" ) as f :
311
- yaml .safe_dump (pod_dict , f , default_flow_style = False )
312
-
313
- def _create_master_pod_obj (self , ** kargs ):
314
- env = []
315
- if "envs" in kargs :
316
- for key in kargs ["envs" ]:
317
- env .append (V1EnvVar (name = key , value = kargs ["envs" ][key ]))
318
- env = append_pod_ip_to_env (env )
319
-
320
- pod = self .create_pod (
321
- pod_name = self .get_master_pod_name (),
322
- job_name = self .job_name ,
323
- image_name = self ._image_name ,
324
- command = ["/bin/bash" ],
325
- resource_requests = kargs ["resource_requests" ],
326
- resource_limits = kargs ["resource_limits" ],
327
- container_args = kargs ["args" ],
328
- pod_priority = kargs ["pod_priority" ],
329
- image_pull_policy = kargs ["image_pull_policy" ],
330
- restart_policy = kargs ["restart_policy" ],
331
- volume = kargs ["volume" ],
332
- owner_pod = None ,
333
- env = env ,
334
- )
335
- # Add replica type and index
336
- pod .metadata .labels [ELASTICDL_REPLICA_TYPE_KEY ] = "master"
337
- pod .metadata .labels [ELASTICDL_REPLICA_INDEX_KEY ] = "0"
338
- pod .api_version = "v1"
339
- pod .kind = "Pod"
340
- return pod
341
-
342
165
def _create_ps_worker_pod (self , pod_name , type_key , index_key , ** kargs ):
343
166
# Find that master pod that will be used as the owner reference
344
167
# for the ps or worker pod.
@@ -380,10 +203,6 @@ def create_ps(self, **kargs):
380
203
pod_name , "ps" , kargs ["ps_id" ], ** kargs
381
204
)
382
205
383
- def delete_master (self ):
384
- logger .info ("pod name is %s" % self .get_master_pod_name ())
385
- self .delete_pod (self .get_master_pod_name ())
386
-
387
206
def delete_worker (self , worker_id ):
388
207
self .delete_pod (self .get_worker_pod_name (worker_id ))
389
208
@@ -489,12 +308,6 @@ def _create_service(self, **kargs):
489
308
service = self .cluster .with_service (service )
490
309
return self .client .create_namespaced_service (self .namespace , service )
491
310
492
- def _get_common_labels (self ):
493
- """Labels that should be attached to all k8s objects belong to
494
- current job.
495
- """
496
- return {"app" : ELASTICDL_APP_NAME , ELASTICDL_JOB_KEY : self .job_name }
497
-
498
311
def get_master_log (self ):
499
312
return self .get_pod_log (self .get_master_pod_name ())
500
313
0 commit comments