11import os
22import re
3+ from collections import defaultdict
34from pathlib import Path
45from typing import Any , Dict , List , Optional , Tuple , TypedDict , Union
56
67from kubernetes_asyncio .client .models .v1_job import V1Job
8+ from kubernetes_asyncio .client .models .v1_pod import V1Pod
79from kubernetes_asyncio .client .rest import ApiException
810from model_engine_server .common .config import hmi_config
911from model_engine_server .common .dtos .batch_jobs import CreateDockerImageBatchJobResourceRequests
1719)
1820from model_engine_server .infra .gateways .resources .k8s_endpoint_resource_delegate import (
1921 get_kubernetes_batch_client ,
22+ get_kubernetes_core_client ,
2023 load_k8s_yaml ,
2124 maybe_load_kube_config ,
2225)
@@ -84,7 +87,7 @@ def _k8s_job_name_from_id(job_id: str):
8487 return f"launch-di-batch-job-{ job_id } "
8588
8689
87- def _parse_job_status_from_k8s_obj (job : V1Job ) -> BatchJobStatus :
90+ def _parse_job_status_from_k8s_obj (job : V1Job , pods : List [ V1Pod ] ) -> BatchJobStatus :
8891 status = job .status
8992 # these counts are the number of pods in some given status
9093 if status .failed is not None and status .failed > 0 :
@@ -94,10 +97,30 @@ def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus:
9497 if status .ready is not None and status .ready > 0 :
9598 return BatchJobStatus .RUNNING # empirically this doesn't happen
9699 if status .active is not None and status .active > 0 :
97- return BatchJobStatus .RUNNING # TODO this might be a mix of pending and running
100+ for pod in pods :
101+ # In case there are multiple pods for a given job (e.g. if a pod gets shut down)
102+ # let's interpret the job as running if any of the pods are running
103+ # I haven't empirically seen this, but guard against it just in case.
104+ if pod .status .phase == "Running" :
105+ return BatchJobStatus .RUNNING
106+ return BatchJobStatus .PENDING
98107 return BatchJobStatus .PENDING
99108
100109
def make_job_id_to_pods_mapping(pods: List[V1Pod]) -> defaultdict:
    """
    Group pods by the job ID stored in their labels.

    Args:
        pods: Pods to group. Each pod's job ID is read from the
            LAUNCH_JOB_ID_LABEL_SELECTOR key of ``pod.metadata.labels``.

    Returns:
        A ``defaultdict(list)`` mapping each job ID to the pods that carry it,
        so looking up a job ID with no pods yields an empty list. Pods without
        the job ID label (including pods with no labels at all) are logged with
        a warning and left out of the mapping.
    """
    job_id_to_pods_mapping = defaultdict(list)
    for pod in pods:
        # `labels` is an optional metadata field and may be None; treat that the
        # same as "label missing" rather than failing on `.get`.
        labels = pod.metadata.labels or {}
        job_id = labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)
        if job_id is None:
            logger.warning(f"Pod {pod.metadata.name} has no job ID label")
            continue
        job_id_to_pods_mapping[job_id].append(pod)
    return job_id_to_pods_mapping
122+
123+
101124class LiveDockerImageBatchJobGateway (DockerImageBatchJobGateway ):
102125 def __init__ (self ):
103126 pass
@@ -282,10 +305,21 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker
282305 logger .exception ("Got an exception when trying to read the Job" )
283306 raise EndpointResourceInfraException from exc
284307
308+ core_client = get_kubernetes_core_client ()
309+ try :
310+ pods = await core_client .list_namespaced_pod (
311+ namespace = hmi_config .endpoint_namespace ,
312+ label_selector = f"{ LAUNCH_JOB_ID_LABEL_SELECTOR } ={ batch_job_id } " ,
313+ )
314+ except ApiException as exc :
315+ logger .exception ("Got an exception when trying to read pods for the Job" )
316+ raise EndpointResourceInfraException from exc
317+ # This pod list isn't always needed, but it's simpler code-wise to always make the request
318+
285319 job_labels = job .metadata .labels
286320 annotations = job .metadata .annotations
287321
288- status = _parse_job_status_from_k8s_obj (job )
322+ status = _parse_job_status_from_k8s_obj (job , pods . items )
289323
290324 return DockerImageBatchJob (
291325 id = batch_job_id ,
@@ -309,6 +343,19 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc
309343 logger .exception ("Got an exception when trying to list the Jobs" )
310344 raise EndpointResourceInfraException from exc
311345
346+ core_client = get_kubernetes_core_client ()
347+ try :
348+ pods = await core_client .list_namespaced_pod (
349+ namespace = hmi_config .endpoint_namespace ,
350+ label_selector = f"{ OWNER_LABEL_SELECTOR } ={ owner } ,job-name" , # get only pods associated with a job
351+ )
352+ except ApiException as exc :
353+ logger .exception ("Got an exception when trying to read pods for the Job" )
354+ raise EndpointResourceInfraException from exc
355+
356+ # Join jobs + pods
357+ pods_per_job = make_job_id_to_pods_mapping (pods .items )
358+
312359 return [
313360 DockerImageBatchJob (
314361 id = job .metadata .labels .get (LAUNCH_JOB_ID_LABEL_SELECTOR ),
@@ -317,7 +364,9 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc
317364 created_at = job .metadata .creation_timestamp ,
318365 completed_at = job .status .completion_time ,
319366 annotations = job .metadata .annotations ,
320- status = _parse_job_status_from_k8s_obj (job ),
367+ status = _parse_job_status_from_k8s_obj (
368+ job , pods_per_job [job .metadata .labels .get (LAUNCH_JOB_ID_LABEL_SELECTOR )]
369+ ),
321370 )
322371 for job in jobs .items
323372 ]
0 commit comments