17
17
import docker
18
18
from jinja2 import Template
19
19
20
+ from elasticdl_client .common import k8s_client as k8s
21
+ from elasticdl_client .common .args import (
22
+ build_arguments_from_parsed_result ,
23
+ parse_envs ,
24
+ wrap_python_args_with_string ,
25
+ )
26
+ from elasticdl_client .common .constants import BashCommandTemplate
27
+ from elasticdl_client .common .log_utils import default_logger as logger
28
+
20
29
21
30
def init_zoo (args ):
22
- print ("Create the Dockerfile for the model zoo." )
31
+ logger . info ("Create the Dockerfile for the model zoo." )
23
32
24
33
# Copy cluster spec file to the current directory if specified
25
34
cluster_spec_path = args .cluster_spec
@@ -37,16 +46,15 @@ def init_zoo(args):
37
46
tmpl_str = """\
38
47
FROM {{ BASE_IMAGE }} as base
39
48
40
- RUN pip install elasticdl_preprocessing
41
- RUN pip install elasticdl
49
+ RUN pip install elasticdl_preprocessing\
50
+ --extra-index-url={{ EXTRA_PYPI_INDEX }}
51
+
52
+ RUN pip install elasticdl --extra-index-url={{ EXTRA_PYPI_INDEX }}
53
+ ENV PATH /usr/local/lib/python3.6/dist-packages/elasticdl/go/bin:$PATH
42
54
43
55
COPY . /model_zoo
44
- {% if EXTRA_PYPI_INDEX %}
45
- RUN pip install -r /model_zoo/requirements.txt\
46
- --extra-index-url={{ EXTRA_PYPI_INDEX }}\
47
- {% else %}\
48
56
RUN pip install -r /model_zoo/requirements.txt\
49
- {% endif % }
57
+ --extra-index-url={{ EXTRA_PYPI_INDEX } }
50
58
51
59
{% if CLUSTER_SPEC_NAME %}\
52
60
COPY ./{{ CLUSTER_SPEC_NAME }} /cluster_spec/{{ CLUSTER_SPEC_NAME }}\
@@ -59,12 +67,12 @@ def init_zoo(args):
59
67
CLUSTER_SPEC_NAME = cluster_spec_name ,
60
68
)
61
69
62
- with open ("./Dockerfile" , mode = "w+ " ) as f :
70
+ with open ("./Dockerfile" , mode = "w" ) as f :
63
71
f .write (docker_file_content )
64
72
65
73
66
74
def build_zoo (args ):
67
- print ("Build the image for the model zoo." )
75
+ logger . info ("Build the image for the model zoo." )
68
76
# Call docker api to build the image
69
77
# Validate the image name schema
70
78
client = _get_docker_client (
@@ -83,7 +91,7 @@ def build_zoo(args):
83
91
84
92
85
93
def push_zoo (args ):
86
- print ("Push the image for the model zoo." )
94
+ logger . info ("Push the image for the model zoo." )
87
95
# Call docker api to push the image to remote registry
88
96
client = _get_docker_client (
89
97
docker_base_url = args .docker_base_url ,
@@ -95,6 +103,140 @@ def push_zoo(args):
95
103
_print_docker_progress (line )
96
104
97
105
106
+ def train (args ):
107
+ container_args = [
108
+ "--worker_image" ,
109
+ args .image ,
110
+ "--model_zoo" ,
111
+ args .model_zoo ,
112
+ "--cluster_spec" ,
113
+ args .cluster_spec ,
114
+ ]
115
+
116
+ container_args .extend (
117
+ build_arguments_from_parsed_result (
118
+ args ,
119
+ filter_args = [
120
+ "model_zoo" ,
121
+ "cluster_spec" ,
122
+ "worker_image" ,
123
+ "force_use_kube_config_file" ,
124
+ "func" ,
125
+ ],
126
+ )
127
+ )
128
+
129
+ _submit_job (args .image , args , container_args )
130
+
131
+
132
+ def evaluate (args ):
133
+ container_args = [
134
+ "--worker_image" ,
135
+ args .image ,
136
+ "--model_zoo" ,
137
+ args .model_zoo ,
138
+ "--cluster_spec" ,
139
+ args .cluster_spec ,
140
+ ]
141
+ container_args .extend (
142
+ build_arguments_from_parsed_result (
143
+ args ,
144
+ filter_args = [
145
+ "model_zoo" ,
146
+ "cluster_spec" ,
147
+ "worker_image" ,
148
+ "force_use_kube_config_file" ,
149
+ "func" ,
150
+ ],
151
+ )
152
+ )
153
+
154
+ _submit_job (args .image , args , container_args )
155
+
156
+
157
+ def predict (args ):
158
+ container_args = [
159
+ "--worker_image" ,
160
+ args .image ,
161
+ "--model_zoo" ,
162
+ args .model_zoo ,
163
+ "--cluster_spec" ,
164
+ args .cluster_spec ,
165
+ ]
166
+
167
+ container_args .extend (
168
+ build_arguments_from_parsed_result (
169
+ args ,
170
+ filter_args = [
171
+ "model_zoo" ,
172
+ "cluster_spec" ,
173
+ "worker_image" ,
174
+ "force_use_kube_config_file" ,
175
+ ],
176
+ )
177
+ )
178
+
179
+ _submit_job (args .image , args , container_args )
180
+
181
+
182
+ def _submit_job (image_name , client_args , container_args ):
183
+ client = k8s .Client (
184
+ image_name = image_name ,
185
+ namespace = client_args .namespace ,
186
+ job_name = client_args .job_name ,
187
+ cluster_spec = client_args .cluster_spec ,
188
+ force_use_kube_config_file = client_args .force_use_kube_config_file ,
189
+ )
190
+
191
+ container_args = wrap_python_args_with_string (container_args )
192
+
193
+ master_client_command = (
194
+ BashCommandTemplate .SET_PIPEFAIL
195
+ + " python -m elasticdl.python.master.main"
196
+ )
197
+ container_args .insert (0 , master_client_command )
198
+ if client_args .log_file_path :
199
+ container_args .append (
200
+ BashCommandTemplate .REDIRECTION .format (client_args .log_file_path )
201
+ )
202
+
203
+ python_command = " " .join (container_args )
204
+ container_args = ["-c" , python_command ]
205
+
206
+ if client_args .yaml :
207
+ client .dump_master_yaml (
208
+ resource_requests = client_args .master_resource_request ,
209
+ resource_limits = client_args .master_resource_limit ,
210
+ args = container_args ,
211
+ pod_priority = client_args .master_pod_priority ,
212
+ image_pull_policy = client_args .image_pull_policy ,
213
+ restart_policy = client_args .restart_policy ,
214
+ volume = client_args .volume ,
215
+ envs = parse_envs (client_args .envs ),
216
+ yaml = client_args .yaml ,
217
+ )
218
+ logger .info (
219
+ "ElasticDL job %s YAML has been dumped into file %s."
220
+ % (client_args .job_name , client_args .yaml )
221
+ )
222
+ else :
223
+ client .create_master (
224
+ resource_requests = client_args .master_resource_request ,
225
+ resource_limits = client_args .master_resource_limit ,
226
+ args = container_args ,
227
+ pod_priority = client_args .master_pod_priority ,
228
+ image_pull_policy = client_args .image_pull_policy ,
229
+ restart_policy = client_args .restart_policy ,
230
+ volume = client_args .volume ,
231
+ envs = parse_envs (client_args .envs ),
232
+ )
233
+ logger .info (
234
+ "ElasticDL job %s was successfully submitted. "
235
+ "The master pod is: %s."
236
+ % (client_args .job_name , client .get_master_pod_name ())
237
+ )
238
+
239
+
98
240
def _get_docker_client (docker_base_url , docker_tlscert , docker_tlskey ):
99
241
if docker_tlscert and docker_tlskey :
100
242
tls_config = docker .tls .TLSConfig (
0 commit comments