Skip to content

Commit 53f8026

Browse files
authored
Step Status Refresh Functionality + Kubernetes Orchestrator Implementation (#3735)
* first checkpoint * new changes * fixes * new changes * small change * deprecate old method * new changes * missing import * listen to events * linting * loop optimization * changed the deprecation warning * new condition * switching to jobs * formatting * handling the store * not allowing finished steps to be updated * docstrings * label param name * removed unused function * comment and formatting * renamed function * moved steps outside * removed unused input * additional check * docstrings and formatting * removed status checks * orchestrator pod updates * new check
1 parent 7ca69ea commit 53f8026

File tree

10 files changed

+382
-130
lines changed

10 files changed

+382
-130
lines changed

src/zenml/cli/pipeline.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,16 +604,27 @@ def delete_pipeline_run(
604604

605605
@runs.command("refresh")
606606
@click.argument("run_name_or_id", type=str, required=True)
607-
def refresh_pipeline_run(run_name_or_id: str) -> None:
607+
@click.option(
608+
"--include-steps",
609+
is_flag=True,
610+
default=False,
611+
help="Also refresh the status of individual steps.",
612+
)
613+
def refresh_pipeline_run(
614+
run_name_or_id: str, include_steps: bool = False
615+
) -> None:
608616
"""Refresh the status of a pipeline run.
609617
610618
Args:
611619
run_name_or_id: The name or ID of the pipeline run to refresh.
620+
include_steps: If True, also refresh the status of individual steps.
612621
"""
613622
try:
614623
# Fetch and update the run
615624
run = Client().get_pipeline_run(name_id_or_prefix=run_name_or_id)
616-
run.refresh_run_status()
625+
run_utils.refresh_run_status(
626+
run=run, include_step_updates=include_steps
627+
)
617628

618629
except KeyError as e:
619630
cli_utils.error(str(e))

src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -807,14 +807,20 @@ def get_pipeline_run_metadata(
807807
settings=settings,
808808
)
809809

810-
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
810+
def fetch_status(
811+
self, run: "PipelineRunResponse", include_steps: bool = False
812+
) -> Tuple[
813+
Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]]
814+
]:
811815
"""Refreshes the status of a specific pipeline run.
812816
813817
Args:
814818
run: The run that was executed by this orchestrator.
819+
include_steps: Whether to fetch steps
815820
816821
Returns:
817-
the actual status of the pipeline job.
822+
A tuple of (pipeline_status, step_statuses_dict).
823+
Step statuses are not supported for SageMaker, so step_statuses_dict will always be None.
818824
819825
Raises:
820826
AssertionError: If the run was not executed by to this orchestrator.
@@ -855,18 +861,21 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
855861
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
856862
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
857863
if status == "Executing":
858-
return ExecutionStatus.RUNNING
864+
pipeline_status = ExecutionStatus.RUNNING
859865
elif status == "Stopping":
860-
return ExecutionStatus.STOPPING
866+
pipeline_status = ExecutionStatus.STOPPING
861867
elif status == "Stopped":
862-
return ExecutionStatus.STOPPED
868+
pipeline_status = ExecutionStatus.STOPPED
863869
elif status == "Failed":
864-
return ExecutionStatus.FAILED
870+
pipeline_status = ExecutionStatus.FAILED
865871
elif status == "Succeeded":
866-
return ExecutionStatus.COMPLETED
872+
pipeline_status = ExecutionStatus.COMPLETED
867873
else:
868874
raise ValueError("Unknown status for the pipeline execution.")
869875

876+
# SageMaker doesn't support step-level status fetching yet
877+
return pipeline_status, None
878+
870879
def compute_metadata(
871880
self,
872881
execution_arn: str,

src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -451,14 +451,20 @@ def get_pipeline_run_metadata(
451451
)
452452
return {}
453453

454-
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
454+
def fetch_status(
455+
self, run: "PipelineRunResponse", include_steps: bool = False
456+
) -> Tuple[
457+
Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]]
458+
]:
455459
"""Refreshes the status of a specific pipeline run.
456460
457461
Args:
458462
run: The run that was executed by this orchestrator.
463+
include_steps: Whether to fetch steps (not supported for AzureML).
459464
460465
Returns:
461-
the actual status of the pipeline execution.
466+
A tuple of (pipeline_status, step_statuses_dict).
467+
Step statuses are not supported for AzureML, so step_statuses_dict will always be None.
462468
463469
Raises:
464470
AssertionError: If the run was not executed by to this orchestrator.
@@ -512,23 +518,23 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
512518
"Preparing",
513519
"Queued",
514520
]:
515-
return ExecutionStatus.INITIALIZING
521+
pipeline_status = ExecutionStatus.INITIALIZING
516522
elif status in ["Running", "Finalizing"]:
517-
return ExecutionStatus.RUNNING
523+
pipeline_status = ExecutionStatus.RUNNING
518524
elif status == "CancelRequested":
519-
return ExecutionStatus.STOPPING
525+
pipeline_status = ExecutionStatus.STOPPING
520526
elif status == "Canceled":
521-
return ExecutionStatus.STOPPED
522-
elif status in [
523-
"Failed",
524-
"NotResponding",
525-
]:
526-
return ExecutionStatus.FAILED
527+
pipeline_status = ExecutionStatus.STOPPED
528+
elif status in ["Failed", "NotResponding"]:
529+
pipeline_status = ExecutionStatus.FAILED
527530
elif status == "Completed":
528-
return ExecutionStatus.COMPLETED
531+
pipeline_status = ExecutionStatus.COMPLETED
529532
else:
530533
raise ValueError("Unknown status for the pipeline job.")
531534

535+
# AzureML doesn't support step-level status fetching yet
536+
return pipeline_status, None
537+
532538
def compute_metadata(self, job: Any) -> Dict[str, MetadataType]:
533539
"""Generate run metadata based on the generated AzureML PipelineJob.
534540

src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -893,14 +893,20 @@ def _configure_container_resources(
893893

894894
return dynamic_component
895895

896-
def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
896+
def fetch_status(
897+
self, run: "PipelineRunResponse", include_steps: bool = False
898+
) -> Tuple[
899+
Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]]
900+
]:
897901
"""Refreshes the status of a specific pipeline run.
898902
899903
Args:
900904
run: The run that was executed by this orchestrator.
905+
include_steps: Whether to fetch steps.
901906
902907
Returns:
903-
the actual status of the pipeline job.
908+
A tuple of (pipeline_status, step_statuses_dict).
909+
Step statuses are not supported for Vertex, so step_statuses_dict will always be None.
904910
905911
Raises:
906912
AssertionError: If the run was not executed by to this orchestrator.
@@ -943,28 +949,31 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
943949
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
944950
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/describe_pipeline_execution.html#
945951
if status == PipelineState.PIPELINE_STATE_UNSPECIFIED:
946-
return run.status
952+
pipeline_status = run.status
947953
elif status in [
948954
PipelineState.PIPELINE_STATE_QUEUED,
949955
PipelineState.PIPELINE_STATE_PENDING,
950956
]:
951-
return ExecutionStatus.INITIALIZING
957+
pipeline_status = ExecutionStatus.INITIALIZING
952958
elif status in [
953959
PipelineState.PIPELINE_STATE_RUNNING,
954960
PipelineState.PIPELINE_STATE_PAUSED,
955961
]:
956-
return ExecutionStatus.RUNNING
962+
pipeline_status = ExecutionStatus.RUNNING
957963
elif status == PipelineState.PIPELINE_STATE_SUCCEEDED:
958-
return ExecutionStatus.COMPLETED
964+
pipeline_status = ExecutionStatus.COMPLETED
959965
elif status == PipelineState.PIPELINE_STATE_CANCELLING:
960-
return ExecutionStatus.STOPPING
966+
pipeline_status = ExecutionStatus.STOPPING
961967
elif status == PipelineState.PIPELINE_STATE_CANCELLED:
962-
return ExecutionStatus.STOPPED
968+
pipeline_status = ExecutionStatus.STOPPED
963969
elif status == PipelineState.PIPELINE_STATE_FAILED:
964-
return ExecutionStatus.FAILED
970+
pipeline_status = ExecutionStatus.FAILED
965971
else:
966972
raise ValueError("Unknown status for the pipeline job.")
967973

974+
# Vertex doesn't support step-level status fetching yet
975+
return pipeline_status, None
976+
968977
def compute_metadata(
969978
self, job: aiplatform.PipelineJob
970979
) -> Dict[str, MetadataType]:

src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py

Lines changed: 182 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from zenml.constants import (
5050
METADATA_ORCHESTRATOR_RUN_ID,
5151
)
52-
from zenml.enums import StackComponentType
52+
from zenml.enums import ExecutionStatus, StackComponentType
5353
from zenml.integrations.kubernetes.flavors.kubernetes_orchestrator_flavor import (
5454
KubernetesOrchestratorConfig,
5555
KubernetesOrchestratorSettings,
@@ -785,6 +785,187 @@ def _stop_run(
785785
f"No running step jobs found for pipeline run with ID: {run.id}"
786786
)
787787

788+
def fetch_status(
789+
self, run: "PipelineRunResponse", include_steps: bool = False
790+
) -> Tuple[
791+
Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]]
792+
]:
793+
"""Refreshes the status of a specific pipeline run.
794+
795+
Args:
796+
run: The run that was executed by this orchestrator.
797+
include_steps: If True, also fetch the status of individual steps.
798+
799+
Returns:
800+
A tuple of (pipeline_status, step_statuses).
801+
If include_steps is False, step_statuses will be None.
802+
If include_steps is True, step_statuses will be a dict (possibly empty).
803+
804+
Raises:
805+
ValueError: If the orchestrator run ID cannot be found or if the
806+
stack components are not accessible.
807+
"""
808+
# Get the orchestrator run ID which corresponds to the orchestrator pod name
809+
orchestrator_run_id = run.orchestrator_run_id
810+
if not orchestrator_run_id:
811+
raise ValueError(
812+
"Cannot determine orchestrator run ID for the run. "
813+
"Unable to fetch the status."
814+
)
815+
816+
# Check the orchestrator pod status (only if run is not finished)
817+
if not run.status.is_finished:
818+
orchestrator_pod_phase = self._check_pod_status(
819+
pod_name=orchestrator_run_id,
820+
)
821+
pipeline_status = self._map_pod_phase_to_execution_status(
822+
orchestrator_pod_phase
823+
)
824+
else:
825+
# Run is already finished, don't change status
826+
pipeline_status = None
827+
828+
step_statuses = None
829+
if include_steps:
830+
step_statuses = self._fetch_step_statuses(run)
831+
832+
return pipeline_status, step_statuses
833+
834+
def _check_pod_status(
835+
self,
836+
pod_name: str,
837+
) -> kube_utils.PodPhase:
838+
"""Check pod status and handle deletion scenarios for both orchestrator and step pods.
839+
840+
This method should only be called for non-finished pipeline runs/steps.
841+
842+
Args:
843+
pod_name: The name of the pod to check.
844+
845+
Returns:
846+
The pod phase if the pod exists, or PodPhase.FAILED if pod was deleted.
847+
"""
848+
pod = kube_utils.get_pod(
849+
core_api=self._k8s_core_api,
850+
pod_name=pod_name,
851+
namespace=self.config.kubernetes_namespace,
852+
)
853+
854+
if pod and pod.status and pod.status.phase:
855+
try:
856+
return kube_utils.PodPhase(pod.status.phase)
857+
except ValueError:
858+
# Handle unknown pod phases
859+
logger.warning(
860+
f"Unknown pod phase for pod {pod_name}: {pod.status.phase}"
861+
)
862+
return kube_utils.PodPhase.UNKNOWN
863+
else:
864+
logger.warning(
865+
f"Can't fetch the status of pod {pod_name} "
866+
f"in namespace {self.config.kubernetes_namespace}."
867+
)
868+
return kube_utils.PodPhase.UNKNOWN
869+
870+
def _map_pod_phase_to_execution_status(
871+
self, pod_phase: kube_utils.PodPhase
872+
) -> Optional[ExecutionStatus]:
873+
"""Map Kubernetes pod phase to ZenML execution status.
874+
875+
Args:
876+
pod_phase: The Kubernetes pod phase.
877+
878+
Returns:
879+
The corresponding ZenML execution status.
880+
"""
881+
if pod_phase == kube_utils.PodPhase.PENDING:
882+
return ExecutionStatus.INITIALIZING
883+
elif pod_phase == kube_utils.PodPhase.RUNNING:
884+
return ExecutionStatus.RUNNING
885+
elif pod_phase == kube_utils.PodPhase.SUCCEEDED:
886+
return ExecutionStatus.COMPLETED
887+
elif pod_phase == kube_utils.PodPhase.FAILED:
888+
return ExecutionStatus.FAILED
889+
else: # UNKNOWN - no update
890+
return None
891+
892+
def _map_job_status_to_execution_status(
893+
self, job: k8s_client.V1Job
894+
) -> Optional[ExecutionStatus]:
895+
"""Map Kubernetes job status to ZenML execution status.
896+
897+
Args:
898+
job: The Kubernetes job.
899+
900+
Returns:
901+
The corresponding ZenML execution status, or None if no clear status.
902+
"""
903+
# Check job conditions first
904+
if job.status and job.status.conditions:
905+
for condition in job.status.conditions:
906+
if condition.type == "Complete" and condition.status == "True":
907+
return ExecutionStatus.COMPLETED
908+
elif condition.type == "Failed" and condition.status == "True":
909+
return ExecutionStatus.FAILED
910+
911+
# Return None if no clear status - don't update
912+
return None
913+
914+
def _fetch_step_statuses(
915+
self, run: "PipelineRunResponse"
916+
) -> Dict[str, ExecutionStatus]:
917+
"""Fetch the statuses of individual pipeline steps.
918+
919+
Args:
920+
run: The pipeline run response.
921+
922+
Returns:
923+
A dictionary mapping step names to their execution statuses.
924+
"""
925+
step_statuses = {}
926+
927+
# Query all jobs for this run and match them to steps
928+
label_selector = f"run_id={kube_utils.sanitize_label(str(run.id))}"
929+
930+
try:
931+
jobs = self._k8s_batch_api.list_namespaced_job(
932+
namespace=self.config.kubernetes_namespace,
933+
label_selector=label_selector,
934+
)
935+
except Exception as e:
936+
logger.warning(f"Failed to list jobs for run {run.id}: {e}")
937+
return {}
938+
939+
# Fetch the steps from the run response
940+
steps_dict = run.steps
941+
942+
for job in jobs.items:
943+
# Extract step name from job labels
944+
if not job.metadata or not job.metadata.labels:
945+
continue
946+
947+
step_name = job.metadata.labels.get("step_name")
948+
if not step_name:
949+
continue
950+
951+
# Check if this step is already finished
952+
step_response = steps_dict.get(step_name, None)
953+
954+
# If the step is not in the run response yet, skip, we can't update
955+
if step_response is None:
956+
continue
957+
958+
# If the step is already in a finished state, skip
959+
if step_response and step_response.status.is_finished:
960+
continue
961+
962+
# Check job status and map to execution status
963+
execution_status = self._map_job_status_to_execution_status(job)
964+
if execution_status is not None:
965+
step_statuses[step_name] = execution_status
966+
967+
return step_statuses
968+
788969
def get_pipeline_run_metadata(
789970
self, run_id: UUID
790971
) -> Dict[str, "MetadataType"]:

0 commit comments

Comments
 (0)