Skip to content

Commit c86357f

Browse files
Properly return PENDING status for docker image batch jobs/fine tune jobs (#318)
1 parent daa8f0a commit c86357f

File tree

5 files changed

+351
-17
lines changed

5 files changed

+351
-17
lines changed

model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import (
1111
LAUNCH_JOB_ID_LABEL_SELECTOR,
1212
_parse_job_status_from_k8s_obj,
13+
make_job_id_to_pods_mapping,
1314
)
1415
from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import (
1516
get_kubernetes_batch_client,
17+
get_kubernetes_core_client,
1618
load_k8s_yaml,
1719
maybe_load_kube_config,
1820
)
@@ -97,14 +99,30 @@ async def list_jobs(
9799
logger.exception("Got an exception when trying to list the Jobs")
98100
raise EndpointResourceInfraException from exc
99101

102+
core_client = get_kubernetes_core_client()
103+
104+
try:
105+
label_selector = f"trigger_id={trigger_id}" if trigger_id else f"owner={owner},job-name"
106+
pods = await core_client.list_namespaced_pod(
107+
namespace=hmi_config.endpoint_namespace,
108+
label_selector=label_selector,
109+
)
110+
except ApiException as exc:
111+
logger.exception("Got an exception when trying to list the Pods")
112+
raise EndpointResourceInfraException from exc
113+
114+
pods_per_job = make_job_id_to_pods_mapping(pods.items)
115+
100116
return [
101117
DockerImageBatchJob(
102118
id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR),
103119
created_by=job.metadata.labels.get("created_by"),
104120
owner=job.metadata.labels.get("owner"),
105121
created_at=job.metadata.creation_timestamp,
106122
completed_at=job.status.completion_time,
107-
status=_parse_job_status_from_k8s_obj(job),
123+
status=_parse_job_status_from_k8s_obj(
124+
job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)]
125+
),
108126
)
109127
for job in jobs.items
110128
]

model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
import re
3+
from collections import defaultdict
34
from pathlib import Path
45
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
56

67
from kubernetes_asyncio.client.models.v1_job import V1Job
8+
from kubernetes_asyncio.client.models.v1_pod import V1Pod
79
from kubernetes_asyncio.client.rest import ApiException
810
from model_engine_server.common.config import hmi_config
911
from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests
@@ -17,6 +19,7 @@
1719
)
1820
from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import (
1921
get_kubernetes_batch_client,
22+
get_kubernetes_core_client,
2023
load_k8s_yaml,
2124
maybe_load_kube_config,
2225
)
@@ -84,7 +87,7 @@ def _k8s_job_name_from_id(job_id: str):
8487
return f"launch-di-batch-job-{job_id}"
8588

8689

87-
def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus:
90+
def _parse_job_status_from_k8s_obj(job: V1Job, pods: List[V1Pod]) -> BatchJobStatus:
8891
status = job.status
8992
# these counts are the number of pods in some given status
9093
if status.failed is not None and status.failed > 0:
@@ -94,10 +97,30 @@ def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus:
9497
if status.ready is not None and status.ready > 0:
9598
return BatchJobStatus.RUNNING # empirically this doesn't happen
9699
if status.active is not None and status.active > 0:
97-
return BatchJobStatus.RUNNING # TODO this might be a mix of pending and running
100+
for pod in pods:
101+
# In case there are multiple pods for a given job (e.g. if a pod gets shut down)
102+
# let's interpret the job as running if any of the pods are running
103+
# I haven't empirically seen this, but guard against it just in case.
104+
if pod.status.phase == "Running":
105+
return BatchJobStatus.RUNNING
106+
return BatchJobStatus.PENDING
98107
return BatchJobStatus.PENDING
99108

100109

110+
def make_job_id_to_pods_mapping(pods: List[V1Pod]) -> defaultdict:
    """
    Returns a defaultdict mapping job IDs to pods

    Pods are grouped by the value of their LAUNCH_JOB_ID_LABEL_SELECTOR label.
    Pods that do not carry that label (including pods with no labels at all)
    are skipped with a warning. Since the result is a defaultdict(list),
    indexing it with a job ID that has no pods yields an empty list rather
    than raising KeyError.
    """
    job_id_to_pods_mapping = defaultdict(list)
    for pod in pods:
        # labels may be None rather than {} for a pod without any labels;
        # treat that the same as "job ID label missing" instead of crashing.
        labels = pod.metadata.labels or {}
        job_id = labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)
        if job_id is None:
            logger.warning(f"Pod {pod.metadata.name} has no job ID label")
            continue
        job_id_to_pods_mapping[job_id].append(pod)
    return job_id_to_pods_mapping
122+
123+
101124
class LiveDockerImageBatchJobGateway(DockerImageBatchJobGateway):
102125
def __init__(self):
103126
pass
@@ -282,10 +305,21 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker
282305
logger.exception("Got an exception when trying to read the Job")
283306
raise EndpointResourceInfraException from exc
284307

308+
core_client = get_kubernetes_core_client()
309+
try:
310+
pods = await core_client.list_namespaced_pod(
311+
namespace=hmi_config.endpoint_namespace,
312+
label_selector=f"{LAUNCH_JOB_ID_LABEL_SELECTOR}={batch_job_id}",
313+
)
314+
except ApiException as exc:
315+
logger.exception("Got an exception when trying to read pods for the Job")
316+
raise EndpointResourceInfraException from exc
317+
# This pod list isn't always needed, but it's simpler code-wise to always make the request
318+
285319
job_labels = job.metadata.labels
286320
annotations = job.metadata.annotations
287321

288-
status = _parse_job_status_from_k8s_obj(job)
322+
status = _parse_job_status_from_k8s_obj(job, pods.items)
289323

290324
return DockerImageBatchJob(
291325
id=batch_job_id,
@@ -309,6 +343,19 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc
309343
logger.exception("Got an exception when trying to list the Jobs")
310344
raise EndpointResourceInfraException from exc
311345

346+
core_client = get_kubernetes_core_client()
347+
try:
348+
pods = await core_client.list_namespaced_pod(
349+
namespace=hmi_config.endpoint_namespace,
350+
label_selector=f"{OWNER_LABEL_SELECTOR}={owner},job-name", # get only pods associated with a job
351+
)
352+
except ApiException as exc:
353+
logger.exception("Got an exception when trying to read pods for the Job")
354+
raise EndpointResourceInfraException from exc
355+
356+
# Join jobs + pods
357+
pods_per_job = make_job_id_to_pods_mapping(pods.items)
358+
312359
return [
313360
DockerImageBatchJob(
314361
id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR),
@@ -317,7 +364,9 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc
317364
created_at=job.metadata.creation_timestamp,
318365
completed_at=job.status.completion_time,
319366
annotations=job.metadata.annotations,
320-
status=_parse_job_status_from_k8s_obj(job),
367+
status=_parse_job_status_from_k8s_obj(
368+
job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)]
369+
),
321370
)
322371
for job in jobs.items
323372
]
model-engine/tests/unit/infra/gateways/k8s_fake_objects.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Various fake k8s objects to be used in mocking out the python k8s api client
2+
# Only classes are defined here. If you need to add various fields to the classes, please do so here.
3+
4+
from dataclasses import dataclass, field
5+
from datetime import datetime
6+
from typing import List, Optional
7+
8+
9+
@dataclass
class FakeK8sV1ObjectMeta:
    """Fake object metadata: name, namespace, annotations, labels and creation time."""

    name: str = "fake_name"
    namespace: str = "fake_namespace"
    annotations: dict = field(default_factory=dict)
    labels: dict = field(default_factory=dict)
    creation_timestamp: datetime = datetime(2021, 1, 1, 0, 0, 0, 0)
    # TODO: everything else


@dataclass
class FakeK8sV1PodStatus:
    """Fake pod status; only the ``phase`` string is modelled."""

    phase: str = "Running"
    # TODO: everything else


@dataclass
class FakeK8sV1JobStatus:
    """Fake job status holding per-state counters and the completion time."""

    active: int = 0
    succeeded: int = 0
    failed: int = 0
    ready: int = 0
    terminating: int = 0
    completion_time: Optional[datetime] = None


@dataclass
class FakeK8sV1Job:
    """Fake job made of metadata plus status."""

    # default_factory gives every fake its own metadata/status objects. A
    # single instance used as the default would be shared by all fakes (label
    # mutations would leak between tests), and dataclasses reject such
    # unhashable defaults outright on Python 3.11+.
    metadata: FakeK8sV1ObjectMeta = field(default_factory=FakeK8sV1ObjectMeta)
    status: FakeK8sV1JobStatus = field(default_factory=FakeK8sV1JobStatus)
    # TODO: spec, api_version, kind


@dataclass
class FakeK8sV1JobList:
    """Fake list response wrapping jobs in ``items``."""

    items: List[FakeK8sV1Job] = field(default_factory=list)


@dataclass
class FakeK8sV1Pod:
    """Fake pod made of metadata plus status."""

    # See FakeK8sV1Job: per-instance defaults instead of one shared object.
    metadata: FakeK8sV1ObjectMeta = field(default_factory=FakeK8sV1ObjectMeta)
    status: FakeK8sV1PodStatus = field(default_factory=FakeK8sV1PodStatus)
    # TODO: spec, api_version, kind


@dataclass
class FakeK8sV1PodList:
    """Fake list response wrapping pods in ``items``."""

    items: List[FakeK8sV1Pod] = field(default_factory=list)


@dataclass
class FakeK8sEnvVar:
    """Fake environment variable: a name/value pair."""

    name: str
    value: str


@dataclass
class FakeK8sDeploymentContainer:
    """Fake deployment container exposing only its ``env`` list."""

    env: List[FakeK8sEnvVar]

model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from dataclasses import dataclass
21
from typing import Any, Dict, List
32
from unittest.mock import AsyncMock, Mock, patch
43

@@ -25,21 +24,11 @@
2524
DictStrStr,
2625
ResourceArguments,
2726
)
27+
from tests.unit.infra.gateways.k8s_fake_objects import FakeK8sDeploymentContainer, FakeK8sEnvVar
2828

2929
MODULE_PATH = "model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate"
3030

3131

32-
@dataclass
33-
class FakeK8sEnvVar:
34-
name: str
35-
value: str
36-
37-
38-
@dataclass
39-
class FakeK8sDeploymentContainer:
40-
env: List[FakeK8sEnvVar]
41-
42-
4332
@pytest.fixture
4433
def mock_get_kubernetes_cluster_version():
4534
mock_version = "1.26"

0 commit comments

Comments
 (0)