Skip to content

Commit 97eeedc

Browse files
Refactor the common modules between main package and elasticdl_client. (#2100)
* Remove the duplicated k8s_resource and k8s_volume modules.
* Remove the api and client modules, and the duplicated methods in args.py, from the main package.
* Remove the entry point in the elasticdl main package and rename the elasticdl_client entry point to elasticdl.
* Add a dependency on elasticdl_client in the main package.
* Make k8s_client in the main package inherit from k8s_client in the client package.
* Revert setup.py and setup_client.py.
* Update imports of the duplicated modules to come from the elasticdl_client package instead of the main package.
* Fix the inheritance issue.
* Revert the changes in build_and_test.sh.
* Keep JobType in the elasticdl main package.
* Set a default value for the image_name argument.
* Rename the console entry of the elasticdl_client package to elasticdl.
* Use the Python PS for the evaluate and predict jobs in the integration test.
* Revert the change that uses the Python PS in the predict/evaluate integration test.
1 parent 313c6b0 commit 97eeedc

25 files changed

+121
-1413
lines changed

elasticdl/python/common/args.py

Lines changed: 6 additions & 553 deletions
Large diffs are not rendered by default.

elasticdl/python/common/constants.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,6 @@ class MetricsDictKey(object):
5555
LABEL = "label"
5656

5757

58-
class DistributionStrategy(object):
59-
LOCAL = "Local"
60-
PARAMETER_SERVER = "ParameterServerStrategy"
61-
ALLREDUCE = "AllreduceStrategy"
62-
63-
6458
class SaveModelConfig(object):
6559
SAVED_MODEL_PATH = "saved_model_path"
6660

@@ -86,8 +80,3 @@ class ReaderType(object):
8680
CSV_READER = "CSV"
8781
ODPS_READER = "ODPS"
8882
RECORDIO_READER = "RecordIO"
89-
90-
91-
class BashCommandTemplate(object):
92-
REDIRECTION = " 2>&1 | tee {}"
93-
SET_PIPEFAIL = "set -o pipefail;"

elasticdl/python/common/k8s_client.py

Lines changed: 17 additions & 204 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,24 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import os
1514
import threading
1615
import time
17-
import traceback
1816

19-
import yaml
20-
from kubernetes import client, config, watch
21-
from kubernetes.client import V1EnvVar, V1EnvVarSource, V1ObjectFieldSelector
17+
from kubernetes import client, watch
2218

23-
from elasticdl.python.common.k8s_resource import parse as parse_resource
24-
from elasticdl.python.common.k8s_volume import parse_volume_and_mount
2519
from elasticdl.python.common.log_utils import default_logger as logger
26-
from elasticdl.python.common.model_utils import load_module
20+
from elasticdl_client.common.k8s_client import (
21+
ELASTICDL_APP_NAME,
22+
ELASTICDL_JOB_KEY,
23+
ELASTICDL_REPLICA_INDEX_KEY,
24+
ELASTICDL_REPLICA_TYPE_KEY,
25+
)
26+
from elasticdl_client.common.k8s_client import Client as BaseClient
27+
from elasticdl_client.common.k8s_client import append_pod_ip_to_env
2728

28-
ELASTICDL_APP_NAME = "elasticdl"
29-
ELASTICDL_JOB_KEY = "elasticdl-job-name"
30-
ELASTICDL_REPLICA_TYPE_KEY = "elasticdl-replica-type"
31-
ELASTICDL_REPLICA_INDEX_KEY = "elasticdl-replica-index"
3229
_PS_SERVICE_PORT = 2222
3330
_WORKER_SERVICE_PORT = 3333
3431
_FTLIB_GOSSIP_CONTAINER_PORT = 7946
35-
_FTLIB_SSH_CONTAINER_PORT = 22
36-
37-
38-
def get_master_pod_name(job_name):
39-
return "elasticdl-%s-master" % job_name
4032

4133

4234
def get_worker_pod_name(job_name, worker_id):
@@ -47,21 +39,7 @@ def get_ps_pod_name(job_name, ps_id):
4739
return "elasticdl-%s-ps-%s" % (job_name, str(ps_id))
4840

4941

50-
def append_pod_ip_to_env(env):
51-
pod_ip_var = V1EnvVar(
52-
name="MY_POD_IP",
53-
value_from=V1EnvVarSource(
54-
field_ref=V1ObjectFieldSelector(field_path="status.podIP")
55-
),
56-
)
57-
if env:
58-
env.append(pod_ip_var)
59-
else:
60-
env = [pod_ip_var]
61-
return env
62-
63-
64-
class Client(object):
42+
class Client(BaseClient):
6543
def __init__(
6644
self,
6745
*,
@@ -88,37 +66,18 @@ def __init__(
8866
running in a K8S environment, it loads the incluster config,
8967
if not, it loads the kube config file.
9068
"""
91-
try:
92-
if (
93-
os.getenv("KUBERNETES_SERVICE_HOST")
94-
and not force_use_kube_config_file
95-
):
96-
# We are running inside a k8s cluster
97-
config.load_incluster_config()
98-
logger.info("Load the incluster config.")
99-
else:
100-
# Use user's kube config
101-
config.load_kube_config()
102-
logger.info("Load the kube config file.")
103-
except Exception as ex:
104-
traceback.print_exc()
105-
raise Exception(
106-
"Failed to load configuration for Kubernetes:\n%s" % str(ex)
107-
)
108-
109-
self.client = client.CoreV1Api()
110-
self.namespace = namespace
111-
self.job_name = job_name
112-
self._image_name = image_name
69+
super().__init__(
70+
image_name=image_name,
71+
namespace=namespace,
72+
job_name=job_name,
73+
cluster_spec=cluster_spec,
74+
force_use_kube_config_file=force_use_kube_config_file,
75+
)
11376
self._event_cb = event_callback
11477
if self._event_cb:
11578
threading.Thread(
11679
target=self._watch, name="event_watcher", daemon=True
11780
).start()
118-
self.cluster = None
119-
if cluster_spec:
120-
cluster_spec_module = load_module(cluster_spec)
121-
self.cluster = cluster_spec_module.cluster
12281

12382
def _watch(self):
12483
while True:
@@ -139,9 +98,6 @@ def _watch(self):
13998
def _get_service_address(self, service_name, port):
14099
return "%s.%s.svc:%d" % (service_name, self.namespace, port)
141100

142-
def get_master_pod_name(self):
143-
return get_master_pod_name(self.job_name)
144-
145101
def get_worker_pod_name(self, worker_id):
146102
return get_worker_pod_name(self.job_name, worker_id)
147103

@@ -164,16 +120,6 @@ def get_ps_service_address(self, ps_id):
164120
self.get_ps_service_name(ps_id), _PS_SERVICE_PORT
165121
)
166122

167-
def patch_labels_to_pod(self, pod_name, labels_dict):
168-
body = {"metadata": {"labels": labels_dict}}
169-
try:
170-
return self.client.patch_namespaced_pod(
171-
name=pod_name, namespace=self.namespace, body=body
172-
)
173-
except client.api_client.ApiException as e:
174-
logger.warning("Exception when patching labels to pod: %s\n" % e)
175-
return None
176-
177123
def get_master_pod(self):
178124
return self.get_pod(self.get_master_pod_name())
179125

@@ -216,129 +162,6 @@ def get_worker_service(self, worker_id):
216162
logger.warning("Exception when reading worker service: %s\n" % e)
217163
return None
218164

219-
@staticmethod
220-
def create_owner_reference(owner_pod):
221-
owner_ref = (
222-
[
223-
client.V1OwnerReference(
224-
api_version="v1",
225-
block_owner_deletion=True,
226-
kind="Pod",
227-
name=owner_pod.metadata.name,
228-
uid=owner_pod.metadata.uid,
229-
)
230-
]
231-
if owner_pod
232-
else None
233-
)
234-
return owner_ref
235-
236-
def create_pod(self, **kargs):
237-
# Container
238-
pod_resource_requests = kargs["resource_requests"]
239-
pod_resource_limits = kargs["resource_limits"]
240-
pod_resource_limits = (
241-
pod_resource_limits
242-
if pod_resource_limits
243-
else pod_resource_requests
244-
)
245-
ports = (
246-
[
247-
client.V1ContainerPort(
248-
container_port=_FTLIB_GOSSIP_CONTAINER_PORT, name="gossip"
249-
),
250-
]
251-
if "expose_ports" in kargs and kargs["expose_ports"]
252-
else None
253-
)
254-
container = client.V1Container(
255-
name=kargs["pod_name"],
256-
image=kargs["image_name"],
257-
command=kargs["command"],
258-
resources=client.V1ResourceRequirements(
259-
requests=parse_resource(pod_resource_requests),
260-
limits=parse_resource(pod_resource_limits),
261-
),
262-
args=kargs["container_args"],
263-
image_pull_policy=kargs["image_pull_policy"],
264-
env=kargs["env"],
265-
ports=ports,
266-
)
267-
268-
# Pod
269-
spec = client.V1PodSpec(
270-
containers=[container],
271-
restart_policy=kargs["restart_policy"],
272-
priority_class_name=kargs["pod_priority"],
273-
termination_grace_period_seconds=kargs.get(
274-
"termination_period", None
275-
),
276-
)
277-
278-
# Mount data path
279-
if kargs["volume"]:
280-
volumes, volume_mounts = parse_volume_and_mount(
281-
kargs["volume"], kargs["pod_name"]
282-
)
283-
spec.volumes = volumes
284-
container.volume_mounts = volume_mounts
285-
286-
pod = client.V1Pod(
287-
spec=spec,
288-
metadata=client.V1ObjectMeta(
289-
name=kargs["pod_name"],
290-
labels=self._get_common_labels(),
291-
owner_references=self.create_owner_reference(
292-
kargs["owner_pod"]
293-
),
294-
namespace=self.namespace,
295-
),
296-
)
297-
if self.cluster:
298-
pod = self.cluster.with_pod(pod)
299-
300-
return pod
301-
302-
def create_master(self, **kargs):
303-
pod = self._create_master_pod_obj(**kargs)
304-
self.client.create_namespaced_pod(self.namespace, pod)
305-
logger.info("Master launched.")
306-
307-
def dump_master_yaml(self, **kargs):
308-
pod = self._create_master_pod_obj(**kargs)
309-
pod_dict = self.client.api_client.sanitize_for_serialization(pod)
310-
with open(kargs["yaml"], "w") as f:
311-
yaml.safe_dump(pod_dict, f, default_flow_style=False)
312-
313-
def _create_master_pod_obj(self, **kargs):
314-
env = []
315-
if "envs" in kargs:
316-
for key in kargs["envs"]:
317-
env.append(V1EnvVar(name=key, value=kargs["envs"][key]))
318-
env = append_pod_ip_to_env(env)
319-
320-
pod = self.create_pod(
321-
pod_name=self.get_master_pod_name(),
322-
job_name=self.job_name,
323-
image_name=self._image_name,
324-
command=["/bin/bash"],
325-
resource_requests=kargs["resource_requests"],
326-
resource_limits=kargs["resource_limits"],
327-
container_args=kargs["args"],
328-
pod_priority=kargs["pod_priority"],
329-
image_pull_policy=kargs["image_pull_policy"],
330-
restart_policy=kargs["restart_policy"],
331-
volume=kargs["volume"],
332-
owner_pod=None,
333-
env=env,
334-
)
335-
# Add replica type and index
336-
pod.metadata.labels[ELASTICDL_REPLICA_TYPE_KEY] = "master"
337-
pod.metadata.labels[ELASTICDL_REPLICA_INDEX_KEY] = "0"
338-
pod.api_version = "v1"
339-
pod.kind = "Pod"
340-
return pod
341-
342165
def _create_ps_worker_pod(self, pod_name, type_key, index_key, **kargs):
343166
# Find that master pod that will be used as the owner reference
344167
# for the ps or worker pod.
@@ -380,10 +203,6 @@ def create_ps(self, **kargs):
380203
pod_name, "ps", kargs["ps_id"], **kargs
381204
)
382205

383-
def delete_master(self):
384-
logger.info("pod name is %s" % self.get_master_pod_name())
385-
self.delete_pod(self.get_master_pod_name())
386-
387206
def delete_worker(self, worker_id):
388207
self.delete_pod(self.get_worker_pod_name(worker_id))
389208

@@ -489,12 +308,6 @@ def _create_service(self, **kargs):
489308
service = self.cluster.with_service(service)
490309
return self.client.create_namespaced_service(self.namespace, service)
491310

492-
def _get_common_labels(self):
493-
"""Labels that should be attached to all k8s objects belong to
494-
current job.
495-
"""
496-
return {"app": ELASTICDL_APP_NAME, ELASTICDL_JOB_KEY: self.job_name}
497-
498311
def get_master_log(self):
499312
return self.get_pod_log(self.get_master_pod_name())
500313

0 commit comments

Comments
 (0)